From df53dfbd0a78b5d9adbea2b33102a33401048b91 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 26 Oct 2019 16:59:14 -0500 Subject: [PATCH] revamps statistics post root_kernel removal --- loopy/statistics.py | 70 +++++++++++---- test/test_statistics.py | 188 ++++++++++++++++++++-------------------- 2 files changed, 146 insertions(+), 112 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index 86f39e55b..c8670e19f 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -39,8 +39,7 @@ from loopy.diagnostic import warn_with_kernel, LoopyError from loopy.symbolic import CoefficientCollector from pytools import ImmutableRecord, memoize_method from loopy.kernel.function_interface import CallableKernel -from loopy.kernel import LoopKernel -from loopy.program import make_program +from loopy.program import Program __doc__ = """ @@ -812,8 +811,8 @@ class CounterBase(CombineMapper): self.callables_table = callables_table self.kernel_rec = kernel_rec - from loopy.type_inference import TypeInferenceMapper - self.type_inf = TypeInferenceMapper(knl, callables_table) + from loopy.type_inference import TypeReader + self.type_inf = TypeReader(knl, callables_table) self.zero = get_kernel_zero_pwqpolynomial(self.knl) self.one = self.zero + 1 @@ -1382,6 +1381,13 @@ def add_assumptions_guard(kernel, pwqpolynomial): def count(kernel, set, space=None): + if isinstance(kernel, Program): + kernel_names = [i for i, clbl in six.iteritems(kernel.callables_table) + if isinstance(clbl, CallableKernel)] + if len(kernel_names) > 1: + raise LoopyError() + return count(kernel[kernel_names[0]], set, space) + try: if space is not None: set = set.align_params(space) @@ -1390,7 +1396,7 @@ def count(kernel, set, space=None): except AttributeError: pass - count = isl.PwQPolynomial.zero( + total_count = isl.PwQPolynomial.zero( set.space .drop_dims(dim_type.set, 0, set.dim(dim_type.set)) .add_dims(dim_type.set, 1)) @@ -1452,7 +1458,7 @@ def count(kernel, set, space=None): # }}} if bset_count is not None: - count += bset_count + total_count += bset_count is_subset = bset <= bset_rebuilt is_superset = bset >= bset_rebuilt @@ -1477,7 +1483,7 @@ def count(kernel, set, space=None): "number of integer points in your loop " "domain.") - return add_assumptions_guard(kernel, count) + return add_assumptions_guard(kernel, total_count) def get_unused_hw_axes_factor(knl, callables_table, insn, @@ -1552,7 +1558,6 @@ def count_insn_runs(knl, callables_table, insn, count_redundant_work, return c -@memoize_method def _get_insn_count(knl, callables_table, insn_id, subgroup_size, count_redundant_work, count_granularity=CountGranularity.WORKITEM): insn = knl.id_to_insn[insn_id] @@ -1657,7 +1662,8 @@ def _get_op_map_for_single_kernel(knl, callables_table, def get_op_map(program, numpy_types=True, count_redundant_work=False, - count_within_subscripts=True, subgroup_size=None): + count_within_subscripts=True, subgroup_size=None, + entrypoint=None): """Count the number of operations in a loopy kernel. @@ -1713,8 +1719,13 @@ def get_op_map(program, numpy_types=True, count_redundant_work=False, """ - if isinstance(program, LoopKernel): - program = make_program(program) + if entrypoint is None: + if len(program.entrypoints) > 1: + raise LoopyError("Must provide entrypoint") + + entrypoint = list(program.entrypoints)[0] + + assert entrypoint in program.entrypoints from loopy.preprocess import preprocess_program, infer_unknown_types program = preprocess_program(program) @@ -1729,7 +1740,7 @@ def get_op_map(program, numpy_types=True, count_redundant_work=False, DeprecationWarning, stacklevel=2) return _get_op_map_for_single_kernel( - program[program.name], program.callables_table, + program[entrypoint], program.callables_table, count_redundant_work=count_redundant_work, count_within_subscripts=count_within_subscripts, subgroup_size=subgroup_size) @@ -1848,7 +1859,7 @@ def _get_mem_access_map_for_single_kernel(knl, callables_table, def get_mem_access_map(program, numpy_types=None, count_redundant_work=False, - subgroup_size=None): + subgroup_size=None, entrypoint=None): """Count the number of memory accesses in a loopy kernel. :arg knl: A :class:`loopy.LoopKernel` whose memory accesses are to be @@ -1929,6 +1940,15 @@ def get_mem_access_map(program, numpy_types=None, count_redundant_work=False, # (now use these counts to, e.g., predict performance) """ + + if entrypoint is None: + if len(program.entrypoints) > 1: + raise LoopyError("Must provide entrypoint") + + entrypoint = list(program.entrypoints)[0] + + assert entrypoint in program.entrypoints + from loopy.preprocess import preprocess_program, infer_unknown_types program = preprocess_program(program) @@ -1942,7 +1962,7 @@ def get_mem_access_map(program, numpy_types=None, count_redundant_work=False, DeprecationWarning, stacklevel=2) return _get_mem_access_map_for_single_kernel( - program[program.name], program.callables_table, + program[entrypoint], program.callables_table, count_redundant_work=count_redundant_work, subgroup_size=subgroup_size) @@ -2004,7 +2024,7 @@ def _get_synchronization_map_for_single_kernel(knl, callables_table, return sync_map -def get_synchronization_map(program, subgroup_size=None): +def get_synchronization_map(program, subgroup_size=None, entrypoint=None): """Count the number of synchronization events each work-item encounters in a loopy kernel. @@ -2040,7 +2060,13 @@ def get_synchronization_map(program, subgroup_size=None): # (now use this count to, e.g., predict performance) """ + if entrypoint is None: + if len(program.entrypoints) > 1: + raise LoopyError("Must provide entrypoint") + + entrypoint = list(program.entrypoints)[0] + assert entrypoint in program.entrypoints from loopy.preprocess import preprocess_program, infer_unknown_types program = preprocess_program(program) @@ -2049,7 +2075,7 @@ def get_synchronization_map(program, subgroup_size=None): program = infer_unknown_types(program, expect_completion=True) return _get_synchronization_map_for_single_kernel( - program[program.name], program.callables_table, + program[entrypoint], program.callables_table, subgroup_size=subgroup_size) # }}} @@ -2083,7 +2109,7 @@ def _gather_access_footprints_for_single_kernel(kernel, ignore_uncountable): return write_footprints, read_footprints -def gather_access_footprints(program, ignore_uncountable=False): +def gather_access_footprints(program, ignore_uncountable=False, entrypoint=None): """Return a dictionary mapping ``(var_name, direction)`` to :class:`islpy.Set` instances capturing which indices of each the array *var_name* are read/written (where *direction* is either ``read`` or @@ -2094,6 +2120,14 @@ def gather_access_footprints(program, ignore_uncountable=False): nonlinear indices) """ + if entrypoint is None: + if len(program.entrypoints) > 1: + raise LoopyError("Must provide entrypoint") + + entrypoint = list(program.entrypoints)[0] + + assert entrypoint in program.entrypoints + # FIMXE: works only for one callable kernel till now. if len([in_knl_callable for in_knl_callable in program.callables_table.values() if isinstance(in_knl_callable, @@ -2112,7 +2146,7 @@ def gather_access_footprints(program, ignore_uncountable=False): read_footprints = [] write_footprints, read_footprints = _gather_access_footprints_for_single_kernel( - program[program.name], ignore_uncountable) + program[entrypoint], ignore_uncountable) write_footprints = AccessFootprintGatherer.combine(write_footprints) read_footprints = AccessFootprintGatherer.combine(read_footprints) diff --git a/test/test_statistics.py b/test/test_statistics.py index ef5450599..a1ee67a8d 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -67,15 +67,15 @@ def test_op_counter_basic(): m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP, knl.name)].eval_with_dict( + f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP, "basic")].eval_with_dict( params) - f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP, knl.name)].eval_with_dict( + f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP, "basic")].eval_with_dict( params) - f32div = op_map[lp.Op(np.float32, 'div', CG.SUBGROUP, knl.name)].eval_with_dict( + f32div = op_map[lp.Op(np.float32, 'div', CG.SUBGROUP, "basic")].eval_with_dict( params) - f64mul = op_map[lp.Op(np.dtype(np.float64), 'mul', CG.SUBGROUP, knl.name) + f64mul = op_map[lp.Op(np.dtype(np.float64), 'mul', CG.SUBGROUP, "basic") ].eval_with_dict(params) - i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP, knl.name) + i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP, "basic") ].eval_with_dict(params) # (count-per-sub-group)*n_subgroups assert f32add == f32mul == f32div == n*m*ell*n_subgroups @@ -102,10 +102,10 @@ def test_op_counter_reduction(): m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP, knl.name)].eval_with_dict( - params) - f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul', CG.SUBGROUP, knl.name) - ].eval_with_dict(params) + f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP, + "matmul_serial")].eval_with_dict(params) + f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul', CG.SUBGROUP, + "matmul_serial")].eval_with_dict(params) # (count-per-sub-group)*n_subgroups assert f32add == f32mul == n*m*ell*n_subgroups @@ -138,13 +138,13 @@ def test_op_counter_logic(): m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP, knl.name)].eval_with_dict( + f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP, "logic")].eval_with_dict( params) - f64add = op_map[lp.Op(np.float64, 'add', CG.SUBGROUP, knl.name)].eval_with_dict( + f64add = op_map[lp.Op(np.float64, 'add', CG.SUBGROUP, "logic")].eval_with_dict( params) - f64div = op_map[lp.Op(np.dtype(np.float64), 'div', CG.SUBGROUP, knl.name) + f64div = op_map[lp.Op(np.dtype(np.float64), 'div', CG.SUBGROUP, "logic") ].eval_with_dict(params) - i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP, knl.name) + i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP, "logic") ].eval_with_dict(params) # (count-per-sub-group)*n_subgroups assert f32mul == n*m*n_subgroups @@ -153,7 +153,7 @@ def test_op_counter_logic(): assert i32add == n*m*n_subgroups -def test_op_counter_specialops(): +def test_op_counter_special_ops(): knl = lp.make_kernel( "{[i,k,j]: 0<=i