diff --git a/loopy/statistics.py b/loopy/statistics.py index c233ab096ca1bb0cace5f9abcd29fae063713434..194775dba582ecfb2c0f6579765fc6a4d74bf6b2 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -32,7 +32,7 @@ from functools import reduce from loopy.kernel.data import ( MultiAssignmentBase, TemporaryVariable, AddressSpace) from loopy.diagnostic import warn_with_kernel, LoopyError -from pytools import Record +from pytools import Record, memoize_method __doc__ = """ @@ -1255,6 +1255,59 @@ def count_insn_runs(knl, insn, count_redundant_work, disregard_local_axes=False) else: return c + +@memoize_method +def _get_insn_count(knl, insn_id, subgroup_size, count_redundant_work, + count_granularity=CountGranularity.WORKITEM): + insn = knl.id_to_insn[insn_id] + + if count_granularity is None: + warn_with_kernel(knl, "get_insn_count_assumes_granularity", + "get_insn_count: No count granularity passed, " + "assuming %s granularity." + % (CountGranularity.WORKITEM)) + count_granularity == CountGranularity.WORKITEM + + if count_granularity == CountGranularity.WORKITEM: + return count_insn_runs( + knl, insn, count_redundant_work=count_redundant_work, + disregard_local_axes=False) + + ct_disregard_local = count_insn_runs( + knl, insn, disregard_local_axes=True, + count_redundant_work=count_redundant_work) + + if count_granularity == CountGranularity.WORKGROUP: + return ct_disregard_local + elif count_granularity == CountGranularity.SUBGROUP: + # get the group size + from loopy.symbolic import aff_to_expr + _, local_size = knl.get_grid_size_upper_bounds() + workgroup_size = 1 + if local_size: + for size in local_size: + s = aff_to_expr(size) + if not isinstance(s, int): + raise LoopyError("Cannot count insn with %s granularity, " + "work-group size is not integer: %s" + % (CountGranularity.SUBGROUP, local_size)) + workgroup_size *= s + + warn_with_kernel(knl, "insn_count_subgroups_upper_bound", + "get_insn_count: when counting instruction %s with " + "count_granularity=%s, using upper bound for work-group size " + "(%d work-items) to compute sub-groups per work-group. When " + "multiple device programs present, actual sub-group count may be" + "lower." % (insn_id, CountGranularity.SUBGROUP, workgroup_size)) + + from pytools import div_ceil + return ct_disregard_local*div_ceil(workgroup_size, subgroup_size) + else: + # this should not happen since this is enforced in Op/MemAccess + raise ValueError("get_insn_count: count_granularity '%s' is" + "not allowed. count_granularity options: %s" + % (count_granularity, CountGranularity.ALL+[None])) + # }}} @@ -1360,77 +1413,18 @@ def get_op_map(knl, numpy_types=True, count_redundant_work=False, "must be integer, 'guess', or, if you're feeling " "lucky, None." % (subgroup_size)) - # ------------------------------ - #class CacheHolder(object): - # pass - - #cache_holder = CacheHolder() - #from pytools import memoize_in - - #@memoize_in(cache_holder, "insn_count") - def get_insn_count(knl, insn, count_granularity=CountGranularity.WORKITEM): - - if count_granularity is None: - warn_with_kernel(knl, "get_insn_count_assumes_granularity", - "get_insn_count: No count granularity passed for " - "Op, assuming %s granularity." - % (CountGranularity.WORKITEM)) - count_granularity == CountGranularity.WORKITEM - - if count_granularity == CountGranularity.WORKITEM: - return count_insn_runs( - knl, insn, count_redundant_work=count_redundant_work, - disregard_local_axes=False) - - ct_disregard_local = count_insn_runs( - knl, insn, disregard_local_axes=True, - count_redundant_work=count_redundant_work) - - if count_granularity == CountGranularity.WORKGROUP: - return ct_disregard_local - elif count_granularity == CountGranularity.SUBGROUP: - # get the group size - from loopy.symbolic import aff_to_expr - _, local_size = knl.get_grid_size_upper_bounds() - workgroup_size = 1 - if local_size: - for size in local_size: - s = aff_to_expr(size) - if not isinstance(s, int): - raise LoopyError("Cannot count insn with %s granularity, " - "work-group size is not integer: %s" - % (CountGranularity.SUBGROUP, local_size)) - workgroup_size *= s - - warn_with_kernel(knl, "insn_count_subgroups_upper_bound", - "get_insn_count: when counting instruction %s with " - "count_granularity=%s, using upper bound for work-group size " - "(%d work-items) to compute sub-groups per work-group. When " - "multiple device programs present, actual sub-group count may be" - "lower." % (insn, CountGranularity.SUBGROUP, workgroup_size)) - - from pytools import div_ceil - return ct_disregard_local*div_ceil(workgroup_size, subgroup_size) - else: - # this should not happen since this is enforced in Op - raise ValueError("get_insn_count: count_granularity '%s' is" - "not allowed. count_granularity options: %s" - % (count_granularity, CountGranularity.ALL+[None])) - # ------------------------------ - op_map = ToCountMap() op_counter = ExpressionOpCounter(knl) for insn in knl.instructions: if isinstance(insn, (CallInstruction, CInstruction, Assignment)): ops = op_counter(insn.assignee) + op_counter(insn.expression) - #op_map = op_map + ops*count_insn_runs( - # knl, insn, - # count_redundant_work=count_redundant_work) for key, val in six.iteritems(ops): op_map = ( op_map + ToCountMap({key: val}) - * get_insn_count(knl, insn, key.count_granularity)) + * _get_insn_count(knl, insn.id, subgroup_size, + count_redundant_work, + key.count_granularity)) elif isinstance(insn, (NoOpInstruction, BarrierInstruction)): pass @@ -1594,63 +1588,6 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, "must be integer, 'guess', or, if you're feeling " "lucky, None." % (subgroup_size)) - class CacheHolder(object): - pass - - cache_holder = CacheHolder() - from pytools import memoize_in - - @memoize_in(cache_holder, "insn_count") - def get_insn_count(knl, insn_id, count_granularity=CountGranularity.WORKITEM): - insn = knl.id_to_insn[insn_id] - - if count_granularity is None: - warn_with_kernel(knl, "get_insn_count_assumes_granularity", - "get_insn_count: No count granularity passed for " - "MemAccess, assuming %s granularity." - % (CountGranularity.WORKITEM)) - count_granularity == CountGranularity.WORKITEM - - if count_granularity == CountGranularity.WORKITEM: - return count_insn_runs( - knl, insn, count_redundant_work=count_redundant_work, - disregard_local_axes=False) - - ct_disregard_local = count_insn_runs( - knl, insn, disregard_local_axes=True, - count_redundant_work=count_redundant_work) - - if count_granularity == CountGranularity.WORKGROUP: - return ct_disregard_local - elif count_granularity == CountGranularity.SUBGROUP: - # get the group size - from loopy.symbolic import aff_to_expr - _, local_size = knl.get_grid_size_upper_bounds() - workgroup_size = 1 - if local_size: - for size in local_size: - s = aff_to_expr(size) - if not isinstance(s, int): - raise LoopyError("Cannot count insn with %s granularity, " - "work-group size is not integer: %s" - % (CountGranularity.SUBGROUP, local_size)) - workgroup_size *= s - - warn_with_kernel(knl, "insn_count_subgroups_upper_bound", - "get_insn_count: when counting instruction %s with " - "count_granularity=%s, using upper bound for work-group size " - "(%d work-items) to compute sub-groups per work-group. When " - "multiple device programs present, actual sub-group count may be" - "lower." % (insn_id, CountGranularity.SUBGROUP, workgroup_size)) - - from pytools import div_ceil - return ct_disregard_local*div_ceil(workgroup_size, subgroup_size) - else: - # this should not happen since this is enforced in MemAccess - raise ValueError("get_insn_count: count_granularity '%s' is" - "not allowed. count_granularity options: %s" - % (count_granularity, CountGranularity.ALL+[None])) - knl = infer_unknown_types(knl, expect_completion=True) knl = preprocess_kernel(knl) @@ -1679,14 +1616,18 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, access_map = ( access_map + ToCountMap({key: val}) - * get_insn_count(knl, insn.id, key.count_granularity)) + * _get_insn_count(knl, insn.id, subgroup_size, + count_redundant_work, + key.count_granularity)) for key, val in six.iteritems(access_assignee.count_map): access_map = ( access_map + ToCountMap({key: val}) - * get_insn_count(knl, insn.id, key.count_granularity)) + * _get_insn_count(knl, insn.id, subgroup_size, + count_redundant_work, + key.count_granularity)) elif isinstance(insn, (NoOpInstruction, BarrierInstruction)): pass