Skip to content
Snippets Groups Projects
Commit 5d1b35de authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Update mem op counter to keep loads and stores apart

parent 39feabf8
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,8 @@ from pymbolic.mapper import CombineMapper
from functools import reduce
class TypeToCountMap:
class ToCountMap:
"""Maps any type of key to an arithmetic type."""
def __init__(self, init_dict=None):
if init_dict is None:
......@@ -46,24 +47,24 @@ class TypeToCountMap:
for k, v in six.iteritems(other.dict):
result[k] = self.dict.get(k, 0) + v
return TypeToCountMap(result)
return ToCountMap(result)
def __radd__(self, other):
if other != 0:
raise ValueError("TypeToCountMap: Attempted to add TypeToCountMap "
"to {} {}. TypeToCountMap may only be added to "
"0 and other TypeToCountMap objects."
raise ValueError("ToCountMap: Attempted to add ToCountMap "
"to {} {}. ToCountMap may only be added to "
"0 and other ToCountMap objects."
.format(type(other), other))
return
return self
def __mul__(self, other):
if isinstance(other, isl.PwQPolynomial):
return TypeToCountMap({index: self.dict[index]*other
return ToCountMap({index: self.dict[index]*other
for index in self.dict.keys()})
else:
raise ValueError("TypeToCountMap: Attempted to multiply "
"TypeToCountMap by {} {}."
raise ValueError("ToCountMap: Attempted to multiply "
"ToCountMap by {} {}."
.format(type(other), other))
__rmul__ = __mul__
......@@ -77,6 +78,9 @@ class TypeToCountMap:
def __str__(self):
return str(self.dict)
def __repr__(self):
return repr(self.dict)
class ExpressionOpCounter(CombineMapper):
......@@ -89,7 +93,7 @@ class ExpressionOpCounter(CombineMapper):
return sum(values)
def map_constant(self, expr):
return TypeToCountMap()
return ToCountMap()
map_tagged_variable = map_constant
map_variable = map_constant
......@@ -111,16 +115,16 @@ class ExpressionOpCounter(CombineMapper):
def map_sum(self, expr):
if expr.children:
return TypeToCountMap(
return ToCountMap(
{self.type_inf(expr): len(expr.children)-1}
) + sum(self.rec(child) for child in expr.children)
else:
return TypeToCountMap()
return ToCountMap()
map_product = map_sum
def map_quotient(self, expr, *args):
return TypeToCountMap({self.type_inf(expr): 1}) \
return ToCountMap({self.type_inf(expr): 1}) \
+ self.rec(expr.numerator) \
+ self.rec(expr.denominator)
......@@ -128,24 +132,24 @@ class ExpressionOpCounter(CombineMapper):
map_remainder = map_quotient # implemented in CombineMapper
def map_power(self, expr):
return TypeToCountMap({self.type_inf(expr): 1}) \
return ToCountMap({self.type_inf(expr): 1}) \
+ self.rec(expr.base) \
+ self.rec(expr.exponent)
def map_left_shift(self, expr): # implemented in CombineMapper
return TypeToCountMap({self.type_inf(expr): 1}) \
return ToCountMap({self.type_inf(expr): 1}) \
+ self.rec(expr.shiftee) \
+ self.rec(expr.shift)
map_right_shift = map_left_shift
def map_bitwise_not(self, expr): # implemented in CombineMapper
return TypeToCountMap({self.type_inf(expr): 1}) \
return ToCountMap({self.type_inf(expr): 1}) \
+ self.rec(expr.child)
def map_bitwise_or(self, expr):
# implemented in CombineMapper, maps to map_sum;
return TypeToCountMap(
return ToCountMap(
{self.type_inf(expr): len(expr.children)-1}
) + sum(self.rec(child) for child in expr.children)
......@@ -210,7 +214,7 @@ class ExpressionSubscriptCounter(CombineMapper):
return sum(values)
def map_constant(self, expr):
return TypeToCountMap()
return ToCountMap()
map_tagged_variable = map_constant
map_variable = map_constant
......@@ -249,13 +253,13 @@ class ExpressionSubscriptCounter(CombineMapper):
if not local_id_found:
# count as uniform access
return TypeToCountMap(
return ToCountMap(
{(self.type_inf(expr), 'uniform'): 1}
) + self.rec(expr.index)
if local_id0 is None:
# only non-zero local id(s) found, assume non-consecutive access
return TypeToCountMap(
return ToCountMap(
{(self.type_inf(expr), 'nonconsecutive'): 1}
) + self.rec(expr.index)
......@@ -274,7 +278,7 @@ class ExpressionSubscriptCounter(CombineMapper):
if coeff_id0 != 1:
# non-consecutive access
return TypeToCountMap(
return ToCountMap(
{(self.type_inf(expr), 'nonconsecutive'): 1}
) + self.rec(expr.index)
......@@ -287,14 +291,14 @@ class ExpressionSubscriptCounter(CombineMapper):
if stride != 1:
# non-consecutive
return TypeToCountMap(
return ToCountMap(
{(self.type_inf(expr), 'nonconsecutive'): 1}
) + self.rec(expr.index)
# else, stride == 1, continue since another idx could contain id0
# loop finished without returning, stride==1 for every instance of local_id0
return TypeToCountMap(
return ToCountMap(
{(self.type_inf(expr), 'consecutive'): 1}
) + self.rec(expr.index)
......@@ -302,7 +306,7 @@ class ExpressionSubscriptCounter(CombineMapper):
if expr.children:
return sum(self.rec(child) for child in expr.children)
else:
return TypeToCountMap()
return ToCountMap()
map_product = map_sum
......@@ -413,7 +417,7 @@ def get_op_poly(knl):
insn_inames = knl.insn_inames(insn)
inames_domain = knl.get_inames_domain(insn_inames)
domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
ops = op_counter(insn.expression)
ops = op_counter(insn.assignee) + op_counter(insn.expression)
op_poly = op_poly + ops*count(knl, domain)
return op_poly
......@@ -429,8 +433,18 @@ def get_DRAM_access_poly(knl): # for now just counting subscripts
insn_inames = knl.insn_inames(insn)
inames_domain = knl.get_inames_domain(insn_inames)
domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
subs = subscript_counter(insn.expression)
subs_poly = subs_poly + subs*count(knl, domain)
subs_expr = subscript_counter(insn.expression)
subs_expr = ToCountMap(dict(
(key + ("load",), val)
for key, val in six.iteritems(subs_expr.dict)))
subs_assignee = subscript_counter(insn.assignee)
subs_assignee = ToCountMap(dict(
(key + ("store",), val)
for key, val in six.iteritems(subs_assignee.dict)))
subs_poly = subs_poly + (subs_expr + subs_assignee)*count(knl, domain)
return subs_poly
......
......@@ -2079,6 +2079,8 @@ def test_vectorize(ctx_factory):
knl = lp.add_and_infer_dtypes(knl, dict(b=np.float32))
knl = lp.split_arg_axis(knl, [("a", 0), ("b", 0)], 4,
split_kwargs=dict(slabs=(0, 1)))
print(knl)
1/0
knl = lp.tag_data_axes(knl, "a,b", "c,vec")
ref_knl = knl
......
......@@ -38,7 +38,7 @@ def test_op_counter_basic():
[
"""
c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
e[i, k] = g[i,k]*h[i,k+1]
e[i, k+1] = g[i,k]*h[i,k+1]
"""
],
name="basic", assumptions="n,m,l >= 1")
......@@ -54,7 +54,7 @@ def test_op_counter_basic():
i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f32 == 3*n*m*l
assert f64 == n*m
assert i32 == n*m
assert i32 == n*m*2
def test_op_counter_reduction():
......@@ -204,14 +204,23 @@ def test_DRAM_access_counter_basic():
m = 256
l = 128
f32 = poly.dict[
(np.dtype(np.float32), 'uniform')
(np.dtype(np.float32), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
f64 = poly.dict[
(np.dtype(np.float64), 'uniform')
(np.dtype(np.float64), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f32 == 3*n*m*l
assert f64 == 2*n*m
f32 = poly.dict[
(np.dtype(np.float32), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
f64 = poly.dict[
(np.dtype(np.float64), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f32 == n*m*l
assert f64 == n*m
def test_DRAM_access_counter_reduction():
......@@ -228,10 +237,15 @@ def test_DRAM_access_counter_reduction():
m = 256
l = 128
f32 = poly.dict[
(np.dtype(np.float32), 'uniform')
(np.dtype(np.float32), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f32 == 2*n*m*l
f32 = poly.dict[
(np.dtype(np.float32), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f32 == n*l
def test_DRAM_access_counter_logic():
......@@ -250,14 +264,19 @@ def test_DRAM_access_counter_logic():
m = 256
l = 128
f32 = poly.dict[
(np.dtype(np.float32), 'uniform')
(np.dtype(np.float32), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
f64 = poly.dict[
(np.dtype(np.float64), 'uniform')
(np.dtype(np.float64), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f32 == 2*n*m
assert f64 == n*m
f64 = poly.dict[
(np.dtype(np.float64), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f64 == n*m
def test_DRAM_access_counter_specialops():
......@@ -278,14 +297,23 @@ def test_DRAM_access_counter_specialops():
m = 256
l = 128
f32 = poly.dict[
(np.dtype(np.float32), 'uniform')
(np.dtype(np.float32), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
f64 = poly.dict[
(np.dtype(np.float64), 'uniform')
(np.dtype(np.float64), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f32 == 2*n*m*l
assert f64 == 2*n*m
f32 = poly.dict[
(np.dtype(np.float32), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
f64 = poly.dict[
(np.dtype(np.float64), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f32 == n*m*l
assert f64 == n*m
def test_DRAM_access_counter_bitwise():
......@@ -309,10 +337,15 @@ def test_DRAM_access_counter_bitwise():
m = 256
l = 128
i32 = poly.dict[
(np.dtype(np.int32), 'uniform')
(np.dtype(np.int32), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert i32 == 4*n*m+2*n*m*l
i32 = poly.dict[
(np.dtype(np.int32), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert i32 == n*m+n*m*l
def test_DRAM_access_counter_mixed():
......@@ -335,14 +368,23 @@ def test_DRAM_access_counter_mixed():
m = 256
l = 128
f64uniform = poly.dict[
(np.dtype(np.float64), 'uniform')
(np.dtype(np.float64), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
f32nonconsec = poly.dict[
(np.dtype(np.float32), 'nonconsecutive')
(np.dtype(np.float32), 'nonconsecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f64uniform == 2*n*m
assert f32nonconsec == 3*n*m*l
f64uniform = poly.dict[
(np.dtype(np.float64), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
f32nonconsec = poly.dict[
(np.dtype(np.float32), 'nonconsecutive', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f64uniform == n*m
assert f32nonconsec == n*m*l
def test_DRAM_access_counter_nonconsec():
......@@ -365,10 +407,10 @@ def test_DRAM_access_counter_nonconsec():
m = 256
l = 128
f64nonconsec = poly.dict[
(np.dtype(np.float64), 'nonconsecutive')
(np.dtype(np.float64), 'nonconsecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
f32nonconsec = poly.dict[
(np.dtype(np.float32), 'nonconsecutive')
(np.dtype(np.float32), 'nonconsecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f64nonconsec == 2*n*m
assert f32nonconsec == 3*n*m*l
......@@ -395,10 +437,10 @@ def test_DRAM_access_counter_consec():
l = 128
print(poly)
f64consec = poly.dict[
(np.dtype(np.float64), 'consecutive')
(np.dtype(np.float64), 'consecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
f32consec = poly.dict[
(np.dtype(np.float32), 'consecutive')
(np.dtype(np.float32), 'consecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f64consec == 2*n*m
assert f32consec == 3*n*m*l
......@@ -468,6 +510,7 @@ def test_all_counters_parallel_matmul():
l = 128
barrier_count = get_barrier_poly(knl).eval_with_dict({'n': n, 'm': n, 'l': n})
assert barrier_count == 0
op_map = get_op_poly(knl)
f32ops = op_map.dict[
......@@ -477,20 +520,25 @@ def test_all_counters_parallel_matmul():
np.dtype(np.int32)
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f32ops == n*m*l*2
assert i32ops == n*m*l*4 + l*n*4
subscript_map = get_DRAM_access_poly(knl)
f32uncoal = subscript_map.dict[
(np.dtype(np.float32), 'nonconsecutive')
(np.dtype(np.float32), 'nonconsecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
f32coal = subscript_map.dict[
(np.dtype(np.float32), 'consecutive')
(np.dtype(np.float32), 'consecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert barrier_count == 0
assert f32ops == n*m*l*2
assert i32ops == n*m*l*4
assert f32uncoal == n*m*l
assert f32coal == n*m*l
f32coal = subscript_map.dict[
(np.dtype(np.float32), 'consecutive', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
assert f32coal == n*l
if __name__ == "__main__":
if len(sys.argv) > 1:
......@@ -498,4 +546,3 @@ if __name__ == "__main__":
else:
from py.test.cmdline import main
main([__file__])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment