Skip to content
Snippets Groups Projects
Commit ffec7cab authored by James Stevens's avatar James Stevens
Browse files

added op counter that distinguishes operations, reg counter still in progress

parent 70194c9e
No related branches found
No related tags found
No related merge requests found
......@@ -220,6 +220,144 @@ class ExpressionOpCounter(CombineMapper):
raise NotImplementedError("ExpressionOpCounter encountered slice, "
"map_slice not implemented.")
class ExpressionOpCounter2(CombineMapper):
def __init__(self, knl):
self.knl = knl
from loopy.expression import TypeInferenceMapper
self.type_inf = TypeInferenceMapper(knl)
def combine(self, values):
return sum(values)
def map_constant(self, expr):
return ToCountMap()
map_tagged_variable = map_constant
map_variable = map_constant
#def map_wildcard(self, expr):
# return 0,0
#def map_function_symbol(self, expr):
# return 0,0
map_call = map_constant
# def map_call_with_kwargs(self, expr): # implemented in CombineMapper
def map_subscript(self, expr): # implemented in CombineMapper
return self.rec(expr.index)
# def map_lookup(self, expr): # implemented in CombineMapper
def map_sum(self, expr):
if expr.children:
return ToCountMap(
{(self.type_inf(expr), 'add'): len(expr.children)-1}
) + sum(self.rec(child) for child in expr.children)
else:
return ToCountMap() #TODO when does this happen?
def map_product(self, expr):
from pymbolic.primitives import is_zero
if expr.children:
# Do not count '(-1)* ' (as produced by
# subtraction in pymbolic): Assume this
# gets implemented as a sign flip or
# as subtraction. (Confirmed to be true on
# at least Nvidia 352.30.)
return sum(ToCountMap({(self.type_inf(expr), 'mul'): 1})
+ self.rec(child)
for child in expr.children
if not is_zero(child + 1)) + \
ToCountMap({(self.type_inf(expr), 'mul'): -1})
else:
return ToCountMap() #TODO when does this happen?
def map_quotient(self, expr, *args):
return ToCountMap({(self.type_inf(expr), 'div'): 1}) \
+ self.rec(expr.numerator) \
+ self.rec(expr.denominator)
map_floor_div = map_quotient
map_remainder = map_quotient # implemented in CombineMapper
def map_power(self, expr):
return ToCountMap({(self.type_inf(expr), 'pow'): 1}) \
+ self.rec(expr.base) \
+ self.rec(expr.exponent)
def map_left_shift(self, expr): # implemented in CombineMapper
return ToCountMap({(self.type_inf(expr), 'shift'): 1}) \
+ self.rec(expr.shiftee) \
+ self.rec(expr.shift)
map_right_shift = map_left_shift
def map_bitwise_not(self, expr): # implemented in CombineMapper
return ToCountMap({(self.type_inf(expr), 'bw'): 1}) \
+ self.rec(expr.child)
def map_bitwise_or(self, expr):
# implemented in CombineMapper, maps to map_sum;
return ToCountMap(
{(self.type_inf(expr), 'bw'): len(expr.children)-1}
) + sum(self.rec(child) for child in expr.children)
map_bitwise_xor = map_bitwise_or
# implemented in CombineMapper, maps to map_sum;
map_bitwise_and = map_bitwise_or
# implemented in CombineMapper, maps to map_sum;
def map_comparison(self, expr): # implemented in CombineMapper
return self.rec(expr.left)+self.rec(expr.right)
def map_logical_not(self, expr):
return self.rec(expr.child)
def map_logical_or(self, expr):
return sum(self.rec(child) for child in expr.children)
map_logical_and = map_logical_or
def map_if(self, expr): # implemented in CombineMapper, recurses
warnings.warn("ExpressionOpCounter counting DRAM accesses as "
"sum of if-statement branches.")
return self.rec(expr.condition) + self.rec(expr.then) + self.rec(expr.else_)
def map_if_positive(self, expr): # implemented in FlopCounter
warnings.warn("ExpressionOpCounter counting DRAM accesses as "
"sum of if_pos-statement branches.")
return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_)
def map_min(self, expr):
# implemented in CombineMapper, maps to map_sum;
return ToCountMap(
{(self.type_inf(expr), 'maxmin'): len(expr.children)-1}
) + sum(self.rec(child) for child in expr.children)
# implemented in CombineMapper, maps to map_sum; # TODO test
map_max = map_min # implemented in CombineMapper, maps to map_sum; # TODO test
def map_common_subexpression(self, expr):
raise NotImplementedError("ExpressionOpCounter encountered "
"common_subexpression, "
"map_common_subexpression not implemented.")
def map_substitution(self, expr):
raise NotImplementedError("ExpressionOpCounter encountered substitution, "
"map_substitution not implemented.")
def map_derivative(self, expr):
raise NotImplementedError("ExpressionOpCounter encountered derivative, "
"map_derivative not implemented.")
def map_slice(self, expr):
raise NotImplementedError("ExpressionOpCounter encountered slice, "
"map_slice not implemented.")
class GlobalSubscriptCounter(CombineMapper):
......@@ -417,6 +555,14 @@ class RegisterUsageEstimator(CombineMapper):
else:
self.vars_found.append(expr)
print("new var found: ", expr)
print("knl.temp_vars: \n", self.knl.temporary_variables)
print("found in temp_vars? ", expr.name in self.knl.temporary_variables)
print("found in inames? ", expr.name in self.knl.all_inames)
#print("knl.vars: \n", self.knl.variables)
if expr.name in self.knl.temporary_variables:
print("local? ", self.knl.temporary_variables[expr.name].is_local)
#print("local? ", self.knl.temporary_variables[expr.name].is_local)
if "_dim_" in str(expr): #TODO how to remove block/thread size/id vars?
return 0
else:
......@@ -602,6 +748,24 @@ def get_op_poly(knl):
return op_poly.dict
def get_op_poly2(knl):
from loopy.preprocess import preprocess_kernel, infer_unknown_types
knl = infer_unknown_types(knl, expect_completion=True)
knl = preprocess_kernel(knl)
op_poly = ToCountMap()
op_counter = ExpressionOpCounter2(knl)
for insn in knl.instructions:
# how many times is this instruction executed?
# check domain size:
insn_inames = knl.insn_inames(insn)
inames_domain = knl.get_inames_domain(insn_inames)
domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
ops = op_counter(insn.assignee) + op_counter(insn.expression)
op_poly = op_poly + ops*count(knl, domain)
return op_poly.dict
def get_gmem_access_poly(knl): # for now just counting subscripts
"""Count the number of global memory accesses in a loopy kernel.
......
......@@ -28,7 +28,7 @@ from pyopencl.tools import ( # noqa
as pytest_generate_tests)
import loopy as lp
from loopy.statistics import get_op_poly, get_gmem_access_poly, get_barrier_poly
from loopy.statistics import get_regs_per_thread
from loopy.statistics import get_op_poly2, get_regs_per_thread
import numpy as np
......@@ -50,9 +50,10 @@ def test_op_counter_basic():
n = 512
m = 256
l = 128
f32 = poly[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
f64 = poly[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
i32 = poly[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
params = {'n': n, 'm': m, 'l': l}
f32 = poly[np.dtype(np.float32)].eval_with_dict(params)
f64 = poly[np.dtype(np.float64)].eval_with_dict(params)
i32 = poly[np.dtype(np.int32)].eval_with_dict(params)
assert f32 == 3*n*m*l
assert f64 == n*m
assert i32 == n*m*2
......@@ -72,7 +73,8 @@ def test_op_counter_reduction():
n = 512
m = 256
l = 128
f32 = poly[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
params = {'n': n, 'm': m, 'l': l}
f32 = poly[np.dtype(np.float32)].eval_with_dict(params)
assert f32 == 2*n*m*l
......@@ -92,9 +94,10 @@ def test_op_counter_logic():
n = 512
m = 256
l = 128
f32 = poly[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
f64 = poly[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
i32 = poly[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
params = {'n': n, 'm': m, 'l': l}
f32 = poly[np.dtype(np.float32)].eval_with_dict(params)
f64 = poly[np.dtype(np.float64)].eval_with_dict(params)
i32 = poly[np.dtype(np.int32)].eval_with_dict(params)
assert f32 == n*m
assert f64 == 3*n*m
assert i32 == n*m
......@@ -118,9 +121,10 @@ def test_op_counter_specialops():
n = 512
m = 256
l = 128
f32 = poly[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
f64 = poly[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
i32 = poly[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
params = {'n': n, 'm': m, 'l': l}
f32 = poly[np.dtype(np.float32)].eval_with_dict(params)
f64 = poly[np.dtype(np.float64)].eval_with_dict(params)
i32 = poly[np.dtype(np.int32)].eval_with_dict(params)
assert f32 == 4*n*m*l
assert f64 == 3*n*m
assert i32 == n*m
......@@ -147,8 +151,9 @@ def test_op_counter_bitwise():
n = 512
m = 256
l = 128
i32 = poly[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
i64 = poly[np.dtype(np.int64)].eval_with_dict({'n': n, 'm': m, 'l': l}) # noqa
params = {'n': n, 'm': m, 'l': l}
i32 = poly[np.dtype(np.int32)].eval_with_dict(params)
i64 = poly[np.dtype(np.int64)].eval_with_dict(params) # noqa
assert np.dtype(np.float64) not in poly
assert i32 == n*m+3*n*m*l
assert i64 == 6*n*m
......@@ -185,6 +190,152 @@ def test_op_counter_triangular_domain():
assert flops == 78
def test_op_counter2_basic():
knl = lp.make_kernel(
"[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
[
"""
c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
e[i, k+1] = -g[i,k]*h[i,k+1]
"""
],
name="basic", assumptions="n,m,l >= 1")
knl = lp.add_and_infer_dtypes(knl,
dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
poly = get_op_poly2(knl)
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f32add = poly[(np.dtype(np.float32), 'add')].eval_with_dict(params)
f32mul = poly[(np.dtype(np.float32), 'mul')].eval_with_dict(params)
f32div = poly[(np.dtype(np.float32), 'div')].eval_with_dict(params)
f64mul = poly[(np.dtype(np.float64), 'mul')].eval_with_dict(params)
i32add = poly[(np.dtype(np.int32), 'add')].eval_with_dict(params)
assert f32add == f32mul == f32div == n*m*l
assert f64mul == n*m
assert i32add == n*m*2
def test_op_counter2_reduction():
knl = lp.make_kernel(
"{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
[
"c[i, j] = sum(k, a[i, k]*b[k, j])"
],
name="matmul_serial", assumptions="n,m,l >= 1")
knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
poly = get_op_poly2(knl)
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f32add = poly[(np.dtype(np.float32), 'add')].eval_with_dict(params)
f32mul = poly[(np.dtype(np.float32), 'mul')].eval_with_dict(params)
assert f32add == f32mul == n*m*l
def test_op_counter2_logic():
knl = lp.make_kernel(
"{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
[
"""
e[i,k] = if(not(k<l-2) and k>6 or k/2==l, g[i,k]*2, g[i,k]+h[i,k]/2)
"""
],
name="logic", assumptions="n,m,l >= 1")
knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64))
poly = get_op_poly2(knl)
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f32mul = poly[(np.dtype(np.float32), 'mul')].eval_with_dict(params)
f64add = poly[(np.dtype(np.float64), 'add')].eval_with_dict(params)
f64div = poly[(np.dtype(np.float64), 'div')].eval_with_dict(params)
i32add = poly[(np.dtype(np.int32), 'add')].eval_with_dict(params)
#f32 = poly[np.dtype(np.float32)].eval_with_dict(params)
#f64 = poly[np.dtype(np.float64)].eval_with_dict(params)
#i32 = poly[np.dtype(np.int32)].eval_with_dict(params)
assert f32mul == n*m
#assert f64 == 3*n*m
assert f64div == 2*n*m #TODO why?
assert f64add == n*m
assert i32add == n*m
def test_op_counter2_specialops():
knl = lp.make_kernel(
"{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
[
"""
c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0)
e[i, k] = (1+g[i,k])**(1+h[i,k+1])
"""
],
name="specialops", assumptions="n,m,l >= 1")
knl = lp.add_and_infer_dtypes(knl,
dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
poly = get_op_poly2(knl)
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f32mul = poly[(np.dtype(np.float32), 'mul')].eval_with_dict(params)
f32div = poly[(np.dtype(np.float32), 'div')].eval_with_dict(params)
f32add = poly[(np.dtype(np.float32), 'add')].eval_with_dict(params)
f64pow = poly[(np.dtype(np.float64), 'pow')].eval_with_dict(params)
f64add = poly[(np.dtype(np.float64), 'add')].eval_with_dict(params)
i32add = poly[(np.dtype(np.int32), 'add')].eval_with_dict(params)
assert f32div == 2*n*m*l
assert f32mul == f32add == n*m*l
assert f64add == 2*n*m
assert f64pow == i32add == n*m
def test_op_counter2_bitwise():
knl = lp.make_kernel(
"{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
[
"""
c[i, j, k] = (a[i,j,k] | 1) + (b[i,j,k] & 1)
e[i, k] = (g[i,k] ^ k)*(~h[i,k+1]) + (g[i, k] << (h[i,k] >> k))
"""
],
name="bitwise", assumptions="n,m,l >= 1")
knl = lp.add_and_infer_dtypes(
knl, dict(
a=np.int32, b=np.int32,
g=np.int64, h=np.int64))
poly = get_op_poly2(knl)
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
i32add = poly[(np.dtype(np.int32), 'add')].eval_with_dict(params)
i32bw = poly[(np.dtype(np.int32), 'bw')].eval_with_dict(params)
i64bw = poly[(np.dtype(np.int64), 'bw')].eval_with_dict(params)
i64mul = poly[(np.dtype(np.int64), 'mul')].eval_with_dict(params)
i64add = poly[(np.dtype(np.int64), 'add')].eval_with_dict(params)
i64shift = poly[(np.dtype(np.int64), 'shift')].eval_with_dict(params)
assert i32add == n*m+n*m*l
assert i32bw == 2*n*m*l
assert i64bw == 2*n*m
assert i64add == i64mul == n*m
assert i64shift == 2*n*m
def test_gmem_access_counter_basic():
knl = lp.make_kernel(
......@@ -203,21 +354,22 @@ def test_gmem_access_counter_basic():
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f32 = poly[
(np.dtype(np.float32), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f64 = poly[
(np.dtype(np.float64), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f32 == 3*n*m*l
assert f64 == 2*n*m
f32 = poly[
(np.dtype(np.float32), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f64 = poly[
(np.dtype(np.float64), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f32 == n*m*l
assert f64 == n*m
......@@ -236,14 +388,15 @@ def test_gmem_access_counter_reduction():
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f32 = poly[
(np.dtype(np.float32), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f32 == 2*n*m*l
f32 = poly[
(np.dtype(np.float32), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f32 == n*l
......@@ -263,18 +416,19 @@ def test_gmem_access_counter_logic():
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f32 = poly[
(np.dtype(np.float32), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f64 = poly[
(np.dtype(np.float64), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f32 == 2*n*m
assert f64 == n*m
f64 = poly[
(np.dtype(np.float64), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f64 == n*m
......@@ -296,21 +450,22 @@ def test_gmem_access_counter_specialops():
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f32 = poly[
(np.dtype(np.float32), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f64 = poly[
(np.dtype(np.float64), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f32 == 2*n*m*l
assert f64 == 2*n*m
f32 = poly[
(np.dtype(np.float32), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f64 = poly[
(np.dtype(np.float64), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f32 == n*m*l
assert f64 == n*m
......@@ -336,14 +491,15 @@ def test_gmem_access_counter_bitwise():
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
i32 = poly[
(np.dtype(np.int32), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert i32 == 4*n*m+2*n*m*l
i32 = poly[
(np.dtype(np.int32), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert i32 == n*m+n*m*l
......@@ -367,21 +523,22 @@ def test_gmem_access_counter_mixed():
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f64uniform = poly[
(np.dtype(np.float64), 'uniform', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f32nonconsec = poly[
(np.dtype(np.float32), 'nonconsecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f64uniform == 2*n*m
assert f32nonconsec == 3*n*m*l
f64uniform = poly[
(np.dtype(np.float64), 'uniform', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f32nonconsec = poly[
(np.dtype(np.float32), 'nonconsecutive', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f64uniform == n*m
assert f32nonconsec == n*m*l
......@@ -406,21 +563,22 @@ def test_gmem_access_counter_nonconsec():
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f64nonconsec = poly[
(np.dtype(np.float64), 'nonconsecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f32nonconsec = poly[
(np.dtype(np.float32), 'nonconsecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f64nonconsec == 2*n*m
assert f32nonconsec == 3*n*m*l
f64nonconsec = poly[
(np.dtype(np.float64), 'nonconsecutive', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f32nonconsec = poly[
(np.dtype(np.float32), 'nonconsecutive', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f64nonconsec == n*m
assert f32nonconsec == n*m*l
......@@ -444,22 +602,23 @@ def test_gmem_access_counter_consec():
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
f64consec = poly[
(np.dtype(np.float64), 'consecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f32consec = poly[
(np.dtype(np.float32), 'consecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f64consec == 2*n*m
assert f32consec == 3*n*m*l
f64consec = poly[
(np.dtype(np.float64), 'consecutive', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f32consec = poly[
(np.dtype(np.float32), 'consecutive', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f64consec == n*m
assert f32consec == n*m*l
......@@ -482,7 +641,8 @@ def test_barrier_counter_nobarriers():
n = 512
m = 256
l = 128
barrier_count = poly.eval_with_dict({'n': n, 'm': m, 'l': l})
params = {'n': n, 'm': m, 'l': l}
barrier_count = poly.eval_with_dict(params)
assert barrier_count == 0
......@@ -507,10 +667,11 @@ def test_barrier_counter_barriers():
n = 512
m = 256
l = 128
barrier_count = poly.eval_with_dict({'n': n, 'm': m, 'l': l})
params = {'n': n, 'm': m, 'l': l}
barrier_count = poly.eval_with_dict(params)
assert barrier_count == 50*10*2
'''
def test_reg_counter_basic():
knl = lp.make_kernel(
......@@ -526,7 +687,7 @@ def test_reg_counter_basic():
knl = lp.add_and_infer_dtypes(knl,
dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
regs = get_regs_per_thread(knl)
#1/0
1/0
assert regs == 6
......@@ -596,7 +757,7 @@ def test_reg_counter_bitwise():
g=np.int64, h=np.int64))
regs = get_regs_per_thread(knl)
assert regs == 6
'''
def test_all_counters_parallel_matmul():
......@@ -613,17 +774,18 @@ def test_all_counters_parallel_matmul():
n = 512
m = 256
l = 128
params = {'n': n, 'm': m, 'l': l}
barrier_count = get_barrier_poly(knl).eval_with_dict({'n': n, 'm': m, 'l': l})
barrier_count = get_barrier_poly(knl).eval_with_dict(params)
assert barrier_count == 0
op_map = get_op_poly(knl)
f32ops = op_map[
np.dtype(np.float32)
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
i32ops = op_map[
np.dtype(np.int32)
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f32ops == n*m*l*2
assert i32ops == n*m*l*4 + l*n*4
......@@ -631,17 +793,17 @@ def test_all_counters_parallel_matmul():
subscript_map = get_gmem_access_poly(knl)
f32uncoal = subscript_map[
(np.dtype(np.float32), 'nonconsecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
f32coal = subscript_map[
(np.dtype(np.float32), 'consecutive', 'load')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f32uncoal == n*m*l
assert f32coal == n*m*l
f32coal = subscript_map[
(np.dtype(np.float32), 'consecutive', 'store')
].eval_with_dict({'n': n, 'm': m, 'l': l})
].eval_with_dict(params)
assert f32coal == n*l
'''
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment