diff --git a/loopy/statistics.py b/loopy/statistics.py index 429a6a2116ab111f6cda85a95afe3255d73014bd..b9c0d701f1b06ecd69c1da402eb0c0adb0759dab 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -377,6 +377,108 @@ class GlobalSubscriptCounter(CombineMapper): raise NotImplementedError("GlobalSubscriptCounter encountered slice, " "map_slice not implemented.") +class RegisterUsageEstimator(CombineMapper): + + def __init__(self, knl): + self.knl = knl + from loopy.expression import TypeInferenceMapper + self.type_inf = TypeInferenceMapper(knl) + + def combine(self, values): + return sum(values) + + def map_constant(self, expr): + return 1 + + map_tagged_variable = map_constant + map_variable = map_constant + map_call = map_constant # TODO what is this? + + def map_subscript(self, expr): + name = expr.aggregate.name # name of array + + if name in self.knl.arg_dict: + array = self.knl.arg_dict[name] + else: + # this is a temporary variable + return 1 # TODO +self.rec(expr.index)? + + if not isinstance(array, lp.GlobalArg): + # this array is not in global memory + return 0 # TODO is this right? recurse on index? + + # this is a global mem access + return 1 # TODO +self.rec(expr.index)? + + def map_sum(self, expr): + if expr.children: + return sum(self.rec(child) for child in expr.children) + else: + return 0 # TODO when does this happen? + + map_product = map_sum + + def map_quotient(self, expr, *args): + return self.rec(expr.numerator) + self.rec(expr.denominator) + + map_floor_div = map_quotient + map_remainder = map_quotient + + def map_power(self, expr): + return self.rec(expr.base) + self.rec(expr.exponent) + + def map_left_shift(self, expr): + return self.rec(expr.shiftee)+self.rec(expr.shift) + + map_right_shift = map_left_shift + + def map_bitwise_not(self, expr): + return self.rec(expr.child) + + def map_bitwise_or(self, expr): + return sum(self.rec(child) for child in expr.children) + + map_bitwise_xor = map_bitwise_or + map_bitwise_and = map_bitwise_or + + def map_comparison(self, expr): + return self.rec(expr.left)+self.rec(expr.right) + + map_logical_not = map_bitwise_not + map_logical_or = map_bitwise_or + map_logical_and = map_logical_or + + def map_if(self, expr): + warnings.warn("RegisterUsageEstimator counting register usage as " + "sum of if-statement branches.") + return self.rec(expr.condition) + self.rec(expr.then) + self.rec(expr.else_) + + def map_if_positive(self, expr): + warnings.warn("RegisterUsageEstimator counting register usage as " + "sum of if_pos-statement branches.") + return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_) + + map_min = map_bitwise_or + map_max = map_min + + def map_common_subexpression(self, expr): + raise NotImplementedError("GlobalSubscriptCounter encountered " + "common_subexpression, " + "map_common_subexpression not implemented.") + + def map_substitution(self, expr): + raise NotImplementedError("GlobalSubscriptCounter encountered " + "substitution, " + "map_substitution not implemented.") + + def map_derivative(self, expr): + raise NotImplementedError("GlobalSubscriptCounter encountered " + "derivative, " + "map_derivative not implemented.") + + def map_slice(self, expr): + raise NotImplementedError("GlobalSubscriptCounter encountered slice, " + "map_slice not implemented.") def count(kernel, bset): try: @@ -596,3 +698,26 @@ def get_barrier_poly(knl): barrier_poly += isl.PwQPolynomial('{ 1 }') return barrier_poly + + +def get_regs_per_thread(knl): + + """Estimate registers per thread usage by a loopy kernel. + + :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated. + + """ + + from loopy.preprocess import preprocess_kernel, infer_unknown_types + from loopy.schedule import RunInstruction + knl = infer_unknown_types(knl, expect_completion=True) + knl = preprocess_kernel(knl) + knl = lp.get_one_scheduled_kernel(knl) + regs = 0 + reg_counter = RegisterUsageEstimator(knl) + for sched_item in knl.schedule: + if isinstance(sched_item, RunInstruction): + insn = knl.id_to_insn[sched_item.insn_id] + regs += reg_counter(insn.expression) + regs += reg_counter(insn.assignee) + return regs diff --git a/test/test_statistics.py b/test/test_statistics.py index a504761193fe4acb7dff9a4a9535efb7a74fe2a9..5426d6312a7cf0f79e964933a0e264495b68ea4b 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -28,6 +28,7 @@ from pyopencl.tools import ( # noqa as pytest_generate_tests) import loopy as lp from loopy.statistics import get_op_poly, get_gmem_access_poly, get_barrier_poly +from loopy.statistics import get_regs_per_thread import numpy as np @@ -510,6 +511,92 @@ def test_barrier_counter_barriers(): assert barrier_count == 50*10*2 +def test_reg_counter_basic(): + + knl = lp.make_kernel( + "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", + [ + """ + c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k] + e[i, k+1] = g[i,k]*h[i,k+1] + """ + ], + name="basic", assumptions="n,m,l >= 1") + + knl = lp.add_and_infer_dtypes(knl, + dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) + regs = get_regs_per_thread(knl) + assert regs == 8 + + +def test_reg_counter_reduction(): + + knl = lp.make_kernel( + "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", + [ + "c[i, j] = sum(k, a[i, k]*b[k, j])" + ], + name="matmul_serial", assumptions="n,m,l >= 1") + + knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) + regs = get_regs_per_thread(knl) + assert regs == 8 + + +def test_reg_counter_logic(): + + knl = lp.make_kernel( + "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", + [ + """ + e[i,k] = if(not(k<l-2) and k>6 or k/2==l, g[i,k]*2, g[i,k]+h[i,k]/2) + """ + ], + name="logic", assumptions="n,m,l >= 1") + + knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64)) + regs = get_regs_per_thread(knl) + assert regs == 14 + + +def test_reg_counter_specialops(): + + knl = lp.make_kernel( + "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", + [ + """ + c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0) + e[i, k] = (1+g[i,k])**(1+h[i,k+1]) + """ + ], + name="specialops", assumptions="n,m,l >= 1") + + knl = lp.add_and_infer_dtypes(knl, + dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) + regs = get_regs_per_thread(knl) + assert regs == 11 + + +def test_reg_counter_bitwise(): + + knl = lp.make_kernel( + "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", + [ + """ + c[i, j, k] = (a[i,j,k] | 1) + (b[i,j,k] & 1) + e[i, k] = (g[i,k] ^ k)*(~h[i,k+1]) + (g[i, k] << (h[i,k] >> k)) + """ + ], + name="bitwise", assumptions="n,m,l >= 1") + + knl = lp.add_and_infer_dtypes( + knl, dict( + a=np.int32, b=np.int32, + g=np.int64, h=np.int64)) + regs = get_regs_per_thread(knl) + assert regs == 12 + + def test_all_counters_parallel_matmul(): knl = lp.make_kernel( @@ -557,6 +644,9 @@ def test_all_counters_parallel_matmul(): assert f32coal == n*l + regs = get_regs_per_thread(knl) + assert regs == 8 + if __name__ == "__main__": if len(sys.argv) > 1: