From 118cb24becb9429ecd8d352465673ac1a0eeeeb7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 31 May 2019 19:01:58 -0500 Subject: [PATCH] Fix loopy.statistics for kernel callables This is a large refactoring, with many pieces: - Counts from subkernels are incorporated using subst_into_{pwqpolynomial,guarded_pwqpolynomial,to_count_map}. This replaces a prior, broken scheme that existed on the kernel callables branch. - Separate ToCountMap and ToCountPolynomialMap, i.e. separate to-count map types by their value type. The latter type now knows (and checks) its isl space. - The numpy_types argument is now deprecated and ignored, it did not seem to do anything previously. - Introduce Sync() count key for synchronization counting. - Code/robustness cleanups in the ToCountMap* types. - All op descriptors now carry a kernel_name. There are still a few FIXMEs, mainly the SUBGROUP granularity and the footprint gatherer. --- loopy/__init__.py | 4 +- loopy/isl_helpers.py | 1 + loopy/statistics.py | 945 ++++++++++++++++++++++------------------ test/test_statistics.py | 68 ++- 4 files changed, 571 insertions(+), 447 deletions(-) diff --git a/loopy/__init__.py b/loopy/__init__.py index a70adf398..fd6c8770c 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -131,7 +131,7 @@ from loopy.type_inference import infer_unknown_types from loopy.preprocess import (preprocess_kernel, realize_reduction, preprocess_program) from loopy.schedule import generate_loop_schedules, get_one_scheduled_kernel -from loopy.statistics import (ToCountMap, CountGranularity, stringify_stats_mapping, +from loopy.statistics import (ToCountMap, CountGranularity, Op, MemAccess, get_op_map, get_mem_access_map, get_synchronization_map, gather_access_footprints, gather_access_footprint_bytes) @@ -269,7 +269,7 @@ __all__ = [ "PreambleInfo", "generate_code", "generate_code_v2", "generate_body", - "ToCountMap", "CountGranularity", "stringify_stats_mapping", "Op", + "ToCountMap", "CountGranularity", 
"Op", "MemAccess", "get_op_map", "get_mem_access_map", "get_synchronization_map", "gather_access_footprints", "gather_access_footprint_bytes", diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index 0eaba8322..0cbd18599 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -828,6 +828,7 @@ def get_param_subst_domain(new_space, base_obj, subst_dict): def subst_into_pwqpolynomial(new_space, poly, subst_dict): if not poly.get_pieces(): + assert new_space.is_params() result = isl.PwQPolynomial.zero(new_space.insert_dims(dim_type.out, 0, 1)) assert result.dim(dim_type.out) == 1 return result diff --git a/loopy/statistics.py b/loopy/statistics.py index 5e4b1ecf1..2c3d4f36f 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -1,6 +1,10 @@ from __future__ import division, absolute_import, print_function -__copyright__ = "Copyright (C) 2015 James Stevens" +__copyright__ = """ +Copyright (C) 2015 James Stevens +Copyright (C) 2018 Kaushik Kulkarni +Copyright (C) 2019 Andreas Kloeckner +""" __license__ = """ Permission is hereby granted, free of charge, to any person obtaining a copy @@ -22,19 +26,19 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from functools import partial import six import loopy as lp from islpy import dim_type import islpy as isl from pymbolic.mapper import CombineMapper -from functools import reduce from loopy.kernel.data import ( MultiAssignmentBase, TemporaryVariable, AddressSpace) from loopy.diagnostic import warn_with_kernel, LoopyError from loopy.symbolic import CoefficientCollector -from pytools import Record, memoize_method -from loopy.kernel.function_interface import ScalarCallable, CallableKernel +from pytools import ImmutableRecord, memoize_method +from loopy.kernel.function_interface import CallableKernel from loopy.kernel import LoopKernel from loopy.program import make_program @@ -44,6 +48,7 @@ __doc__ = """ .. currentmodule:: loopy .. autoclass:: ToCountMap +.. 
autoclass:: ToCountPolynomialMap .. autoclass:: CountGranularity .. autoclass:: Op .. autoclass:: MemAccess @@ -63,13 +68,29 @@ __doc__ = """ """ -# FIXME: this is broken for the callable kernel design. -# - The variable name, what if multiple kernels use the same name?(needs a -# different MemAccessInfo) -# - We should also add the cumulative effect on the arguments of callee kernels -# into the caller kernel -# - Make changes to MemAccessInfo to include the effect of several kernels. -# - Renovate `count`. +# FIXME: +# - The SUBGROUP granularity is completely broken if the root kernel +# contains the grid and the operations get counted in the callee. +# To test, most of those are set to WORKITEM instead below (marked +# with FIXMEs). This leads to value mismatches and key errors in +# the tests. +# - Currently, nothing prevents summation across different +# granularities, which is guaranteed to yield bogus results. +# - AccessFootprintGatherer needs to be redone to match get_op_map and +# get_mem_access_map style +# - Test for the subkernel functionality need to be written + + +def get_kernel_parameter_space(kernel): + return isl.Space.create_from_names(kernel.isl_context, + set=[], params=kernel.outer_params()).params() + + +def get_kernel_zero_pwqpolynomial(kernel): + space = get_kernel_parameter_space(kernel) + space = space.insert_dims(dim_type.out, 0, 1) + return isl.PwQPolynomial.zero(space) + # {{{ GuardedPwQPolynomial @@ -87,6 +108,10 @@ class GuardedPwQPolynomial(object): assert (_get_param_tuple(pwqpolynomial.space) == _get_param_tuple(valid_domain.space)) + @property + def space(self): + return self.valid_domain.space + def __add__(self, other): if isinstance(other, GuardedPwQPolynomial): return GuardedPwQPolynomial( @@ -143,7 +168,20 @@ class GuardedPwQPolynomial(object): # {{{ ToCountMap class ToCountMap(object): - """Maps any type of key to an arithmetic type. 
+ """A map from work descriptors like :class:`Op` and :class:`MemAccess` + to any arithmetic type. + + .. automethod:: __getitem__ + .. automethod:: __str__ + .. automethod:: __repr__ + .. automethod:: __len__ + .. automethod:: get + .. automethod:: items + .. automethod:: keys + .. automethod:: values + + .. automethod:: copy + .. automethod:: with_set_attributes .. automethod:: filter_by .. automethod:: filter_by_func @@ -154,23 +192,20 @@ class ToCountMap(object): """ - def __init__(self, init_dict=None, val_type=GuardedPwQPolynomial): - if init_dict is None: - init_dict = {} + def __init__(self, count_map=None): + if count_map is None: + count_map = {} - for val in init_dict.values(): - if isinstance(val, isl.PwQPolynomial): - assert val.dim(dim_type.out) - elif isinstance(val, GuardedPwQPolynomial): - assert val.pwqpolynomial.dim(dim_type.out) - self.count_map = init_dict - self.val_type = val_type + self.count_map = count_map + + def _zero(self): + return 0 def __add__(self, other): result = self.count_map.copy() for k, v in six.iteritems(other.count_map): result[k] = self.count_map.get(k, 0) + v - return ToCountMap(result, self.val_type) + return self.copy(count_map=result) def __radd__(self, other): if other != 0: @@ -178,32 +213,18 @@ class ToCountMap(object): "to {0} {1}. ToCountMap may only be added to " "0 and other ToCountMap objects." .format(type(other), other)) + return self def __mul__(self, other): - if isinstance(other, GuardedPwQPolynomial): - return ToCountMap(dict( - (index, value*other) - for index, value in six.iteritems(self.count_map))) - else: - raise ValueError("ToCountMap: Attempted to multiply " - "ToCountMap by {0} {1}." - .format(type(other), other)) + return self.copy(dict( + (index, value*other) + for index, value in six.iteritems(self.count_map))) __rmul__ = __mul__ def __getitem__(self, index): - try: - return self.count_map[index] - except KeyError: - #TODO what is the best way to handle this? 
- if self.val_type is GuardedPwQPolynomial: - return GuardedPwQPolynomial.zero() - else: - return 0 - - def __setitem__(self, index, value): - self.count_map[index] = value + return self.count_map[index] def __repr__(self): return repr(self.count_map) @@ -225,17 +246,19 @@ class ToCountMap(object): def keys(self): return self.count_map.keys() - def pop(self, item): - return self.count_map.pop(item) + def values(self): + return self.count_map.values() + + def copy(self, count_map=None): + if count_map is None: + count_map = self.count_map - def copy(self): - return ToCountMap(dict(self.count_map), self.val_type) + return type(self)(count_map=count_map) def with_set_attributes(self, **kwargs): - return ToCountMap(dict( + return self.copy(count_map=dict( (key.copy(**kwargs), val) - for key, val in six.iteritems(self.count_map)), - self.val_type) + for key, val in six.iteritems(self.count_map))) def filter_by(self, **kwargs): """Remove items without specified key fields. @@ -262,28 +285,25 @@ class ToCountMap(object): """ - result_map = ToCountMap(val_type=self.val_type) - - from loopy.types import to_loopy_type - if 'dtype' in kwargs.keys(): - kwargs['dtype'] = [to_loopy_type(d) for d in kwargs['dtype']] - - # for each item in self.count_map - for self_key, self_val in self.items(): - try: - # check to see if key attribute values match all filters - for arg_field, allowable_vals in kwargs.items(): - attr_val = getattr(self_key, arg_field) - # see if the value is in the filter list - if attr_val not in allowable_vals: - break - else: # loop terminated without break or error - result_map[self_key] = self_val - except(AttributeError): - # the field passed is not a field of this key - continue - - return result_map + new_count_map = {} + + class _Sentinel: + pass + + new_kwargs = {} + for arg_field, allowable_vals in six.iteritems(kwargs): + if arg_field == "dtype": + from loopy.types import to_loopy_type + allowable_vals = [to_loopy_type(dtype) for dtype in 
allowable_vals] + + new_kwargs[arg_field] = allowable_vals + + for key, val in six.iteritems(self.count_map): + if all(getattr(key, arg_field, _Sentinel) in allowable_vals + for arg_field, allowable_vals in six.iteritems(new_kwargs)): + new_count_map[key] = val + + return self.copy(count_map=new_count_map) def filter_by_func(self, func): """Keep items that pass a test. @@ -310,14 +330,13 @@ class ToCountMap(object): """ - result_map = ToCountMap(val_type=self.val_type) + new_count_map = {} - # for each item in self.count_map, call func on the key - for self_key, self_val in self.items(): + for self_key, self_val in six.iteritems(self.count_map): if func(self_key): - result_map[self_key] = self_val + new_count_map[self_key] = self_val - return result_map + return self.copy(count_map=new_count_map) def group_by(self, *args): """Group map items together, distinguishing by only the key fields @@ -365,7 +384,7 @@ class ToCountMap(object): """ - result_map = ToCountMap(val_type=self.val_type) + new_count_map = {} # make sure all item keys have same type if self.count_map: @@ -374,22 +393,17 @@ class ToCountMap(object): raise ValueError("ToCountMap: group_by() function may only " "be used on ToCountMaps with uniform keys") else: - return result_map - - # for each item in self.count_map - for self_key, self_val in self.items(): - new_key = key_type() + return self - # set all specified fields - for field in args: - setattr(new_key, field, getattr(self_key, field)) + for self_key, self_val in six.iteritems(self.count_map): + new_key = key_type( + **dict( + (field, getattr(self_key, field)) + for field in args)) - if new_key in result_map.keys(): - result_map[new_key] += self_val - else: - result_map[new_key] = self_val + new_count_map[new_key] = new_count_map.get(new_key, 0) + self_val - return result_map + return self.copy(count_map=new_count_map) def to_bytes(self): """Convert counts to bytes using data type in map key. 
@@ -422,34 +436,69 @@ class ToCountMap(object): """ - result = self.copy() + new_count_map = {} - for key, val in self.items(): - bytes_processed = int(key.dtype.itemsize) * val - result[key] = bytes_processed + for key, val in six.iteritems(self.count_map): + new_count_map[key] = int(key.dtype.itemsize) * val - #TODO again, is this okay? - result.val_type = int - - return result + return self.copy(new_count_map) def sum(self): - """Add all counts in ToCountMap. - - :return: An :class:`islpy.PwQPolynomial` or :class:`int` containing the - sum of counts. + """:return: A sum of the values of the dictionary.""" - """ - - if self.val_type is GuardedPwQPolynomial: - total = GuardedPwQPolynomial.zero() - else: - total = 0 + total = self._zero() - for k, v in self.items(): + for k, v in six.iteritems(self.count_map): total += v + return total +# }}} + + +# {{{ ToCountPolynomialMap + +class ToCountPolynomialMap(ToCountMap): + """Maps any type of key to a :class:`islpy.PwQPolynomial` or a + :class:`GuardedPwQPolynomial`. 
+ """ + + def __init__(self, space, count_map=None): + if not isinstance(space, isl.Space): + raise TypeError( + "first argument to ToCountPolynomialMap must be " + "of type islpy.Space") + + assert space.is_params() + self.space = space + + space_param_tuple = _get_param_tuple(space) + + for key, val in six.iteritems(count_map): + if isinstance(val, isl.PwQPolynomial): + assert val.dim(dim_type.out) == 1 + elif isinstance(val, GuardedPwQPolynomial): + assert val.pwqpolynomial.dim(dim_type.out) == 1 + else: + raise TypeError("unexpected value type") + + assert _get_param_tuple(val.space) == space_param_tuple + + super(ToCountPolynomialMap, self).__init__(count_map) + + def _zero(self): + space = self.space.insert_dims(dim_type.out, 0, 1) + return isl.PwQPolynomial.zero(space) + + def copy(self, count_map=None, space=None): + if count_map is None: + count_map = self.count_map + + if space is None: + space = self.space + + return type(self)(space, count_map) + #TODO test and document def eval(self, params): result = self.copy() @@ -458,12 +507,11 @@ class ToCountMap(object): result.val_type = int return result - def eval_and_sum(self, params): - """Add all counts in :class:`ToCountMap` and evaluate with provided - parameter dict. + def eval_and_sum(self, params=None): + """Add all counts and evaluate with provided parameter dict *params* - :return: An :class:`int` containing the sum of all counts in the - :class:`ToCountMap` evaluated with the parameters provided. + :return: An :class:`int` containing the sum of all counts + evaluated with the parameters provided. 
Example usage:: @@ -478,6 +526,9 @@ class ToCountMap(object): # (now use these counts to, e.g., predict performance) """ + if params is None: + params = {} + return self.sum().eval_with_dict(params) # }}} @@ -504,35 +555,29 @@ def subst_into_guarded_pwqpolynomial(new_space, guarded_poly, subst_dict): def subst_into_to_count_map(space, tcm, subst_dict): from loopy.isl_helpers import subst_into_pwqpolynomial - result = {} + new_count_map = {} for key, value in six.iteritems(tcm.count_map): - # FIXME: This strips away the guards. Rather than being stripped, - # they should also have the substitution applied if isinstance(value, GuardedPwQPolynomial): - result[key] = subst_into_guarded_pwqpolynomial(space, value, subst_dict) + new_count_map[key] = subst_into_guarded_pwqpolynomial( + space, value, subst_dict) elif isinstance(value, isl.PwQPolynomial): - result[key] = subst_into_pwqpolynomial(space, value, subst_dict) + new_count_map[key] = subst_into_pwqpolynomial(space, value, subst_dict) elif isinstance(value, int): - result[key] = value + new_count_map[key] = value else: raise ValueError("unexpected value type") - return ToCountMap(result, val_type=isl.PwQPolynomial) + return tcm.copy(space=space, count_map=new_count_map) # }}} -def stringify_stats_mapping(m): - result = "" - for key in sorted(m.keys(), key=lambda k: str(k)): - result += ("%s : %s\n" % (key, m[key])) - return result - +# {{{ CountGranularity -class CountGranularity: +class CountGranularity(object): """Strings specifying whether an operation should be counted once per *work-item*, *sub-group*, or *work-group*. @@ -558,10 +603,12 @@ class CountGranularity: WORKGROUP = "workgroup" ALL = [WORKITEM, SUBGROUP, WORKGROUP] +# }}} + # {{{ Op descriptor -class Op(Record): +class Op(ImmutableRecord): """A descriptor for a type of arithmetic operation. .. attribute:: dtype @@ -599,18 +646,14 @@ class Op(Record): raise ValueError("Op.__init__: count_granularity '%s' is " "not allowed. 
count_granularity options: %s" % (count_granularity, CountGranularity.ALL+[None])) - if dtype is None: - Record.__init__(self, dtype=dtype, name=name, - count_granularity=count_granularity, - kernel_name=kernel_name) - else: + + if dtype is not None: from loopy.types import to_loopy_type - Record.__init__(self, dtype=to_loopy_type(dtype), name=name, - count_granularity=count_granularity, - kernel_name=kernel_name) + dtype = to_loopy_type(dtype) - def __hash__(self): - return hash(repr(self)) + super(Op, self).__init__(dtype=dtype, name=name, + count_granularity=count_granularity, + kernel_name=kernel_name) def __repr__(self): # Record.__repr__ overridden for consistent ordering and conciseness @@ -625,7 +668,7 @@ class Op(Record): # {{{ MemAccess descriptor -class MemAccess(Record): +class MemAccess(ImmutableRecord): """A descriptor for a type of memory access. .. attribute:: mtype @@ -698,24 +741,19 @@ class MemAccess(Record): "not allowed. count_granularity options: %s" % (count_granularity, CountGranularity.ALL+[None])) - if dtype is None: - Record.__init__(self, mtype=mtype, dtype=dtype, lid_strides=lid_strides, - gid_strides=gid_strides, direction=direction, - variable=variable, variable_tag=variable_tag, - count_granularity=count_granularity, - kernel_name=kernel_name) - else: + if dtype is not None: from loopy.types import to_loopy_type - Record.__init__(self, mtype=mtype, dtype=to_loopy_type(dtype), - lid_strides=lid_strides, gid_strides=gid_strides, - direction=direction, variable=variable, - variable_tag=variable_tag, - count_granularity=count_granularity, - kernel_name=kernel_name) + dtype = to_loopy_type(dtype) + + super(MemAccess, self).__init__(mtype=mtype, dtype=dtype, + lid_strides=lid_strides, gid_strides=gid_strides, + direction=direction, variable=variable, + variable_tag=variable_tag, + count_granularity=count_granularity, + kernel_name=kernel_name) def __hash__(self): - # Note that this means lid_strides and gid_strides must be sorted - # in 
self.__repr__() + # dicts in gid_strides and lid_strides aren't natively hashable return hash(repr(self)) def __repr__(self): @@ -736,29 +774,97 @@ class MemAccess(Record): # }}} -# {{{ counter base +# {{{ Sync descriptor + +class Sync(ImmutableRecord): + """A descriptor for a type of synchronization. + + .. attribute:: kind + + A string describing the synchronization kind, e.g. ``"barrier_global"`` or + ``"barrier_local"`` or ``"kernel_launch"``. + + .. attribute:: kernel_name + + A :class:`str` representing the kernel name where the operation occurred. + """ + + def __init__(self, kind=None, kernel_name=None): + super(Sync, self).__init__(kind=kind, kernel_name=kernel_name) + + def __repr__(self): + # Record.__repr__ overridden for consistent ordering and conciseness + return "Sync(%s, %s)" % (self.kind, self.kernel_name) + +# }}} + + +# {{{ CounterBase class CounterBase(CombineMapper): - def __init__(self, knl, callables_table): + def __init__(self, knl, callables_table, kernel_rec): self.knl = knl self.callables_table = callables_table + self.kernel_rec = kernel_rec + from loopy.type_inference import TypeInferenceMapper self.type_inf = TypeInferenceMapper(knl, callables_table) + self.zero = get_kernel_zero_pwqpolynomial(self.knl) + self.one = self.zero + 1 + + @property + @memoize_method + def param_space(self): + return get_kernel_parameter_space(self.knl) + + def new_poly_map(self, count_map): + return ToCountPolynomialMap(self.param_space, count_map) + + def new_zero_poly_map(self): + return self.new_poly_map({}) + def combine(self, values): return sum(values) def map_constant(self, expr): - return ToCountMap() + return self.new_zero_poly_map() def map_call(self, expr): - return self.rec(expr.parameters) + from loopy.symbolic import ResolvedFunction + assert isinstance(expr.function, ResolvedFunction) + clbl = self.callables_table[expr.function.name] + + from loopy.kernel.function_interface import CallableKernel + from loopy.kernel.data import ValueArg + if 
isinstance(clbl, CallableKernel): + sub_result = self.kernel_rec(clbl.subkernel) + + assert len(clbl.subkernel.args) == len(expr.parameters) + arg_dict = dict( + (arg.name, value) + for arg, value in zip( + clbl.subkernel.args, + expr.parameters) + if isinstance(arg, ValueArg)) + + return subst_into_to_count_map( + self.param_space, + sub_result, arg_dict) \ + + self.rec(expr.parameters) + + else: + raise NotImplementedError() + + def map_call_with_kwargs(self, expr): + # FIXME + raise NotImplementedError() def map_sum(self, expr): if expr.children: return sum(self.rec(child) for child in expr.children) else: - return ToCountMap() + return self.new_zero_poly_map() map_product = map_sum @@ -798,68 +904,82 @@ class CounterBase(CombineMapper): # {{{ ExpressionOpCounter class ExpressionOpCounter(CounterBase): - def __init__(self, knl, callables_table, count_within_subscripts=True): - self.knl = knl - self.callables_table = callables_table + def __init__(self, knl, callables_table, kernel_rec, + count_within_subscripts=True): + super(ExpressionOpCounter, self).__init__( + knl, callables_table, kernel_rec) self.count_within_subscripts = count_within_subscripts - from loopy.type_inference import TypeInferenceMapper - self.type_inf = TypeInferenceMapper(knl, callables_table) + + # FIXME: Revert to SUBGROUP + arithmetic_count_granularity = CountGranularity.WORKITEM def combine(self, values): return sum(values) def map_constant(self, expr): - return ToCountMap() + return self.new_zero_poly_map() map_tagged_variable = map_constant map_variable = map_constant def map_call(self, expr): from loopy.symbolic import ResolvedFunction - if isinstance(expr.function, ResolvedFunction): - function_identifier = self.callables_table[ - expr.function.name].name + assert isinstance(expr.function, ResolvedFunction) + clbl = self.callables_table[expr.function.name] + + from loopy.kernel.function_interface import CallableKernel + if not isinstance(clbl, CallableKernel): + return 
self.new_poly_map( + {Op(dtype=self.type_inf(expr), + name='func:'+clbl.name, + count_granularity=self.arithmetic_count_granularity, + kernel_name=self.knl.name): self.one} + ) + self.rec(expr.parameters) else: - function_identifier = expr.function.name - - return ToCountMap( - {Op(dtype=self.type_inf(expr), - name='func:'+function_identifier, - count_granularity=CountGranularity.SUBGROUP): 1} - ) + self.rec(expr.parameters) + return super(ExpressionOpCounter, self).map_call(expr) def map_subscript(self, expr): if self.count_within_subscripts: return self.rec(expr.index) else: - return ToCountMap() + return self.new_zero_poly_map() + + def map_sub_array_ref(self, expr): + # generates an array view, considered free + return self.new_zero_poly_map() def map_sum(self, expr): assert expr.children - return ToCountMap( + return self.new_poly_map( {Op(dtype=self.type_inf(expr), name='add', - count_granularity=CountGranularity.SUBGROUP): - len(expr.children)-1} + count_granularity=self.arithmetic_count_granularity, + kernel_name=self.knl.name): + self.zero + (len(expr.children)-1)} ) + sum(self.rec(child) for child in expr.children) def map_product(self, expr): from pymbolic.primitives import is_zero assert expr.children - return sum(ToCountMap({Op(dtype=self.type_inf(expr), + return sum(self.new_poly_map({Op(dtype=self.type_inf(expr), name='mul', - count_granularity=CountGranularity.SUBGROUP): 1}) + count_granularity=( + self.arithmetic_count_granularity), + kernel_name=self.knl.name): self.one}) + self.rec(child) for child in expr.children if not is_zero(child + 1)) + \ - ToCountMap({Op(dtype=self.type_inf(expr), + self.new_poly_map({Op(dtype=self.type_inf(expr), name='mul', - count_granularity=CountGranularity.SUBGROUP): -1}) + count_granularity=( + self.arithmetic_count_granularity), + kernel_name=self.knl.name): -self.one}) def map_quotient(self, expr, *args): - return ToCountMap({Op(dtype=self.type_inf(expr), + return self.new_poly_map({Op(dtype=self.type_inf(expr), 
name='div', - count_granularity=CountGranularity.SUBGROUP): 1}) \ + count_granularity=self.arithmetic_count_granularity, + kernel_name=self.knl.name): self.one}) \ + self.rec(expr.numerator) \ + self.rec(expr.denominator) @@ -867,32 +987,36 @@ class ExpressionOpCounter(CounterBase): map_remainder = map_quotient def map_power(self, expr): - return ToCountMap({Op(dtype=self.type_inf(expr), + return self.new_poly_map({Op(dtype=self.type_inf(expr), name='pow', - count_granularity=CountGranularity.SUBGROUP): 1}) \ + count_granularity=self.arithmetic_count_granularity, + kernel_name=self.knl.name): self.one}) \ + self.rec(expr.base) \ + self.rec(expr.exponent) def map_left_shift(self, expr): - return ToCountMap({Op(dtype=self.type_inf(expr), + return self.new_poly_map({Op(dtype=self.type_inf(expr), name='shift', - count_granularity=CountGranularity.SUBGROUP): 1}) \ + count_granularity=self.arithmetic_count_granularity, + kernel_name=self.knl.name): self.one}) \ + self.rec(expr.shiftee) \ + self.rec(expr.shift) map_right_shift = map_left_shift def map_bitwise_not(self, expr): - return ToCountMap({Op(dtype=self.type_inf(expr), + return self.new_poly_map({Op(dtype=self.type_inf(expr), name='bw', - count_granularity=CountGranularity.SUBGROUP): 1}) \ + count_granularity=self.arithmetic_count_granularity, + kernel_name=self.knl.name): self.one}) \ + self.rec(expr.child) def map_bitwise_or(self, expr): - return ToCountMap({Op(dtype=self.type_inf(expr), + return self.new_poly_map({Op(dtype=self.type_inf(expr), name='bw', - count_granularity=CountGranularity.SUBGROUP): - len(expr.children)-1}) \ + count_granularity=self.arithmetic_count_granularity, + kernel_name=self.knl.name): + self.zero + (len(expr.children)-1)}) \ + sum(self.rec(child) for child in expr.children) map_bitwise_xor = map_bitwise_or @@ -913,9 +1037,10 @@ class ExpressionOpCounter(CounterBase): + self.rec(expr.else_) def map_min(self, expr): - return ToCountMap({Op(dtype=self.type_inf(expr), + return 
self.new_poly_map({Op(dtype=self.type_inf(expr), name='maxmin', - count_granularity=CountGranularity.SUBGROUP): + count_granularity=self.arithmetic_count_granularity, + kernel_name=self.knl.name): len(expr.children)-1}) \ + sum(self.rec(child) for child in expr.children) @@ -956,6 +1081,8 @@ class _IndexStrideCoefficientCollector(CoefficientCollector): # }}} +# {{{ _get_lid_and_gid_strides + def _get_lid_and_gid_strides(knl, array, index): # find all local and global index tags and corresponding inames from loopy.symbolic import get_dependencies @@ -1024,28 +1151,50 @@ def _get_lid_and_gid_strides(knl, array, index): return get_iname_strides(lid_to_iname), get_iname_strides(gid_to_iname) +# }}} + + +# {{{ MemAccessCounterBase + +class MemAccessCounterBase(CounterBase): + def map_sub_array_ref(self, expr): + # generates an array view, considered free + return self.new_zero_poly_map() + + def map_call(self, expr): + from loopy.symbolic import ResolvedFunction + assert isinstance(expr.function, ResolvedFunction) + clbl = self.callables_table[expr.function.name] + + from loopy.kernel.function_interface import CallableKernel + if not isinstance(clbl, CallableKernel): + return self.rec(expr.parameters) + else: + return super(MemAccessCounterBase, self).map_call(expr) -class MemAccessCounter(CounterBase): - pass +# }}} # {{{ LocalMemAccessCounter -class LocalMemAccessCounter(MemAccessCounter): +class LocalMemAccessCounter(MemAccessCounterBase): + # FIXME: Revert to SUBGROUP + local_mem_count_granularity = CountGranularity.WORKITEM + def count_var_access(self, dtype, name, index): - sub_map = ToCountMap() + count_map = {} if name in self.knl.temporary_variables: array = self.knl.temporary_variables[name] if isinstance(array, TemporaryVariable) and ( array.address_space == AddressSpace.LOCAL): if index is None: # no subscript - sub_map[MemAccess( + count_map[MemAccess( mtype='local', dtype=dtype, - count_granularity=CountGranularity.SUBGROUP) - ] = 1 - return sub_map + 
count_granularity=self.local_mem_count_granularity, + kernel_name=self.knl.name)] = self.one + return self.new_poly_map(count_map) array = self.knl.temporary_variables[name] @@ -1057,15 +1206,16 @@ class LocalMemAccessCounter(MemAccessCounter): lid_strides, gid_strides = _get_lid_and_gid_strides( self.knl, array, index_tuple) - sub_map[MemAccess( + count_map[MemAccess( mtype='local', dtype=dtype, lid_strides=dict(sorted(six.iteritems(lid_strides))), gid_strides=dict(sorted(six.iteritems(gid_strides))), variable=name, - count_granularity=CountGranularity.SUBGROUP)] = 1 + count_granularity=self.local_mem_count_granularity, + kernel_name=self.knl.name)] = self.one - return sub_map + return self.new_poly_map(count_map) def map_variable(self, expr): return self.count_var_access( @@ -1084,7 +1234,7 @@ class LocalMemAccessCounter(MemAccessCounter): # {{{ GlobalMemAccessCounter -class GlobalMemAccessCounter(MemAccessCounter): +class GlobalMemAccessCounter(MemAccessCounterBase): def map_variable(self, expr): name = expr.name @@ -1092,17 +1242,18 @@ class GlobalMemAccessCounter(MemAccessCounter): array = self.knl.arg_dict[name] else: # this is a temporary variable - return ToCountMap() + return self.new_zero_poly_map() if not isinstance(array, lp.ArrayArg): # this array is not in global memory - return ToCountMap() + return self.new_zero_poly_map() - return ToCountMap({MemAccess(mtype='global', - dtype=self.type_inf(expr), lid_strides={}, - gid_strides={}, variable=name, - count_granularity=CountGranularity.WORKITEM): 1} - ) + self.rec(expr.index) + return self.new_poly_map({MemAccess(mtype='global', + dtype=self.type_inf(expr), lid_strides={}, + gid_strides={}, variable=name, + count_granularity=CountGranularity.WORKITEM, + kernel_name=self.knl.name): self.one} + ) + self.rec(expr.index) def map_subscript(self, expr): name = expr.aggregate.name @@ -1128,19 +1279,28 @@ class GlobalMemAccessCounter(MemAccessCounter): lid_strides, gid_strides = _get_lid_and_gid_strides( 
self.knl, array, index_tuple) - count_granularity = CountGranularity.WORKITEM if ( - 0 in lid_strides and lid_strides[0] != 0 - ) else CountGranularity.SUBGROUP + # FIXME: Revert to subgroup + global_access_count_granularity = CountGranularity.WORKITEM - return ToCountMap({MemAccess( + # Account for broadcasts once per subgroup + count_granularity = CountGranularity.WORKITEM if ( + # if the stride in lid.0 is known + 0 in lid_strides + and + # it is nonzero + lid_strides[0] != 0 + ) else global_access_count_granularity + + return self.new_poly_map({MemAccess( mtype='global', dtype=self.type_inf(expr), lid_strides=dict(sorted(six.iteritems(lid_strides))), gid_strides=dict(sorted(six.iteritems(gid_strides))), variable=name, variable_tag=var_tag, - count_granularity=count_granularity - ): 1} + count_granularity=count_granularity, + kernel_name=self.knl.name, + ): self.one} ) + self.rec(expr.index_tuple) # }}} @@ -1216,7 +1376,9 @@ class AccessFootprintGatherer(CombineMapper): # {{{ count def add_assumptions_guard(kernel, pwqpolynomial): - return GuardedPwQPolynomial(pwqpolynomial, kernel.assumptions) + return GuardedPwQPolynomial( + pwqpolynomial, + kernel.assumptions.align_params(pwqpolynomial.space)) def count(kernel, set, space=None): @@ -1319,7 +1481,7 @@ def count(kernel, set, space=None): def get_unused_hw_axes_factor(knl, callables_table, insn, - disregard_local_axes, space=None): + disregard_local_axes): # FIXME: Multi-kernel support gsize, lsize = knl.get_grid_size_upper_bounds(callables_table) @@ -1338,12 +1500,12 @@ def get_unused_hw_axes_factor(knl, callables_table, insn, g_used.add(tag.axis) def mult_grid_factor(used_axes, size): - result = 1 + result = get_kernel_zero_pwqpolynomial(knl) + 1 + for iaxis, size in enumerate(size): if iaxis not in used_axes: if not isinstance(size, int): - if space is not None: - size = size.align_params(space) + size = size.align_params(result.space) size = isl.PwQPolynomial.from_pw_aff(size) @@ -1359,6 +1521,16 @@ def 
get_unused_hw_axes_factor(knl, callables_table, insn, return add_assumptions_guard(knl, result) +def count_inames_domain(knl, inames): + space = get_kernel_parameter_space(knl) + if not inames: + return get_kernel_zero_pwqpolynomial(knl) + 1 + + inames_domain = knl.get_inames_domain(inames) + domain = inames_domain.project_out_except(inames, [dim_type.set]) + return count(knl, domain, space=space) + + def count_insn_runs(knl, callables_table, insn, count_redundant_work, disregard_local_axes=False): @@ -1370,18 +1542,11 @@ def count_insn_runs(knl, callables_table, insn, count_redundant_work, [iname for iname in insn_inames if not knl.iname_tags_of_type(iname, LocalIndexTag)]) - inames_domain = knl.get_inames_domain(insn_inames) - domain = (inames_domain.project_out_except( - insn_inames, [dim_type.set])) - - space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT, - set=[], params=knl.outer_params()) - - c = count(knl, domain, space=space) + c = count_inames_domain(knl, insn_inames) if count_redundant_work: unused_fac = get_unused_hw_axes_factor(knl, callables_table, - insn, disregard_local_axes=disregard_local_axes, space=space) + insn, disregard_local_axes=disregard_local_axes) return c * unused_fac else: return c @@ -1412,7 +1577,8 @@ def _get_insn_count(knl, callables_table, insn_id, subgroup_size, if count_granularity == CountGranularity.WORKGROUP: return ct_disregard_local elif count_granularity == CountGranularity.SUBGROUP: - # get the group size + # {{{ compute workgroup_size + from loopy.symbolic import aff_to_expr _, local_size = knl.get_grid_size_upper_bounds(callables_table) workgroup_size = 1 @@ -1425,15 +1591,18 @@ def _get_insn_count(knl, callables_table, insn_id, subgroup_size, % (CountGranularity.SUBGROUP, local_size)) workgroup_size *= s + # }}} + warn_with_kernel(knl, "insn_count_subgroups_upper_bound", "get_insn_count: when counting instruction %s with " "count_granularity=%s, using upper bound for work-group size " "(%d work-items) to compute 
sub-groups per work-group. When " - "multiple device programs present, actual sub-group count may be" + "multiple device programs present, actual sub-group count may be " "lower." % (insn_id, CountGranularity.SUBGROUP, workgroup_size)) from pytools import div_ceil return ct_disregard_local*div_ceil(workgroup_size, subgroup_size) + else: # this should not happen since this is enforced in Op/MemAccess raise ValueError("get_insn_count: count_granularity '%s' is" @@ -1445,9 +1614,9 @@ def _get_insn_count(knl, callables_table, insn_id, subgroup_size, # {{{ get_op_map -def get_op_map_for_single_kernel(knl, callables_table, - numpy_types=True, count_redundant_work=False, - count_within_subscripts=True, subgroup_size=None): +def _get_op_map_for_single_kernel(knl, callables_table, + count_redundant_work, + count_within_subscripts, subgroup_size): if not knl.options.ignore_boostable_into: raise LoopyError("Kernel '%s': Using operation counting requires the option " @@ -1455,9 +1624,15 @@ def get_op_map_for_single_kernel(knl, callables_table, subgroup_size = _process_subgroup_size(knl, subgroup_size) - op_map = ToCountMap() - op_counter = ExpressionOpCounter(knl, callables_table, + kernel_rec = partial(_get_op_map_for_single_kernel, + callables_table=callables_table, + count_redundant_work=count_redundant_work, + count_within_subscripts=count_within_subscripts, + subgroup_size=subgroup_size) + + op_counter = ExpressionOpCounter(knl, callables_table, kernel_rec, count_within_subscripts) + op_map = op_counter.new_zero_poly_map() from loopy.kernel.instruction import ( CallInstruction, CInstruction, Assignment, @@ -1465,14 +1640,12 @@ def get_op_map_for_single_kernel(knl, callables_table, for insn in knl.instructions: if isinstance(insn, (CallInstruction, CInstruction, Assignment)): - ops = op_counter(insn.assignee) + op_counter(insn.expression) + ops = op_counter(insn.assignees) + op_counter(insn.expression) for key, val in six.iteritems(ops.count_map): - op_map = ( - op_map - + 
ToCountMap({key: val}) - * _get_insn_count(knl, callables_table, insn.id, + count = _get_insn_count(knl, callables_table, insn.id, subgroup_size, count_redundant_work, - key.count_granularity)) + key.count_granularity) + op_map = op_map + ToCountMap({key: val}) * count elif isinstance(insn, (NoOpInstruction, BarrierInstruction)): pass @@ -1480,15 +1653,7 @@ def get_op_map_for_single_kernel(knl, callables_table, raise NotImplementedError("unexpected instruction item type: '%s'" % type(insn).__name__) - if numpy_types: - return ToCountMap( - init_dict=dict( - (op.copy(dtype=op.dtype.numpy_dtype), ct) - for op, ct in six.iteritems(op_map.count_map)), - val_type=op_map.val_type - ) - else: - return op_map + return op_map def get_op_map(program, numpy_types=True, count_redundant_work=False, @@ -1498,10 +1663,6 @@ def get_op_map(program, numpy_types=True, count_redundant_work=False, :arg knl: A :class:`loopy.LoopKernel` whose operations are to be counted. - :arg numpy_types: A :class:`bool` specifying whether the types in the - returned mapping should be numpy types instead of - :class:`loopy.LoopyType`. - :arg count_redundant_work: Based on usage of hardware axes or other specifics, a kernel may perform work redundantly. This :class:`bool` flag indicates whether this work should be included in the count. @@ -1519,7 +1680,7 @@ def get_op_map(program, numpy_types=True, count_redundant_work=False, specifies that it should only be counted once per sub-group. If set to *None* an attempt to find the sub-group size using the device will be made, if this fails an error will be raised. If a :class:`str` - ``'guess'`` is passed as the subgroup_size, get_mem_access_map will + ``'guess'`` is passed as the subgroup_size, :func:`get_op_map` will attempt to find the sub-group size using the device and, if unsuccessful, will make a wild guess. 
@@ -1556,34 +1717,28 @@ def get_op_map(program, numpy_types=True, count_redundant_work=False, program = make_program(program) from loopy.preprocess import preprocess_program, infer_unknown_types - program = infer_unknown_types(program, expect_completion=True) program = preprocess_program(program) - op_map = ToCountMap() - - callables_count = ( - program.callables_table.callables_count) - - for func_id, in_knl_callable in program.callables_table.items(): - if isinstance(in_knl_callable, CallableKernel): - knl = in_knl_callable.subkernel - knl_op_map = get_op_map_for_single_kernel(knl, - program.callables_table, numpy_types, count_redundant_work, - count_within_subscripts, subgroup_size) + # Ordering restriction: preprocess might insert arguments to + # make strides valid. Those also need to go through type inference. + program = infer_unknown_types(program, expect_completion=True) - for i in range(callables_count[func_id]): - op_map += knl_op_map - elif isinstance(in_knl_callable, ScalarCallable): - pass - else: - raise NotImplementedError("Unknown callabke types %s." % ( - type(in_knl_callable).__name__)) + if numpy_types is not None: + from warnings import warn + warn("numpy_types is being ignored and will be removed in 2020.", + DeprecationWarning, stacklevel=2) - return op_map + return _get_op_map_for_single_kernel( + program[program.name], program.callables_table, + count_redundant_work=count_redundant_work, + count_within_subscripts=count_within_subscripts, + subgroup_size=subgroup_size) # }}} +# {{{ subgoup size finding + def _find_subgroup_size_for_knl(knl): from loopy.target.pyopencl import PyOpenCLTarget if isinstance(knl.target, PyOpenCLTarget) and knl.target.device is not None: @@ -1635,11 +1790,13 @@ def _process_subgroup_size(knl, subgroup_size_requested): "must be integer, 'guess', or, if you're feeling " "lucky, None." 
% (subgroup_size_requested)) +# }}} + # {{{ get_mem_access_map -def get_mem_access_map_for_single_kernel(knl, callables_table, - numpy_types=True, count_redundant_work=False, subgroup_size=None): +def _get_mem_access_map_for_single_kernel(knl, callables_table, + count_redundant_work, subgroup_size): if not knl.options.ignore_boostable_into: raise LoopyError("Kernel '%s': Using operation counting requires the option " @@ -1647,9 +1804,16 @@ def get_mem_access_map_for_single_kernel(knl, callables_table, subgroup_size = _process_subgroup_size(knl, subgroup_size) - access_map = ToCountMap() - access_counter_g = GlobalMemAccessCounter(knl, callables_table) - access_counter_l = LocalMemAccessCounter(knl, callables_table) + kernel_rec = partial(_get_mem_access_map_for_single_kernel, + callables_table=callables_table, + count_redundant_work=count_redundant_work, + subgroup_size=subgroup_size) + + access_counter_g = GlobalMemAccessCounter( + knl, callables_table, kernel_rec) + access_counter_l = LocalMemAccessCounter( + knl, callables_table, kernel_rec) + access_map = access_counter_g.new_zero_poly_map() from loopy.kernel.instruction import ( CallInstruction, CInstruction, Assignment, @@ -1657,62 +1821,39 @@ def get_mem_access_map_for_single_kernel(knl, callables_table, for insn in knl.instructions: if isinstance(insn, (CallInstruction, CInstruction, Assignment)): - access_expr = ( - access_counter_g(insn.expression) - + access_counter_l(insn.expression) - ).with_set_attributes(direction="load") - - access_assignee = ( - access_counter_g(insn.assignee) - + access_counter_l(insn.assignee) - ).with_set_attributes(direction="store") - - for key, val in six.iteritems(access_expr.count_map): - - access_map = ( - access_map - + ToCountMap({key: val}) - * _get_insn_count(knl, callables_table, insn.id, - subgroup_size, count_redundant_work, - key.count_granularity)) - - for key, val in six.iteritems(access_assignee.count_map): - - access_map = ( - access_map - + ToCountMap({key: 
val}) - * _get_insn_count(knl, callables_table, insn.id, + insn_access_map = ( + access_counter_g(insn.expression) + + access_counter_l(insn.expression) + ).with_set_attributes(direction="load") + for assignee in insn.assignees: + insn_access_map = insn_access_map + ( + access_counter_g(insn.assignee) + + access_counter_l(insn.assignee) + ).with_set_attributes(direction="store") + + for key, val in six.iteritems(insn_access_map.count_map): + count = _get_insn_count(knl, callables_table, insn.id, subgroup_size, count_redundant_work, - key.count_granularity)) + key.count_granularity) + access_map = access_map + ToCountMap({key: val}) * count elif isinstance(insn, (NoOpInstruction, BarrierInstruction)): pass + else: raise NotImplementedError("unexpected instruction item type: '%s'" % type(insn).__name__) - if numpy_types: - return ToCountMap( - init_dict=dict( - (mem_access.copy(dtype=mem_access.dtype.numpy_dtype), ct) - for mem_access, ct in six.iteritems(access_map.count_map)), - val_type=access_map.val_type - ) - else: - return access_map + return access_map -def get_mem_access_map(program, numpy_types=True, count_redundant_work=False, +def get_mem_access_map(program, numpy_types=None, count_redundant_work=False, subgroup_size=None): """Count the number of memory accesses in a loopy kernel. :arg knl: A :class:`loopy.LoopKernel` whose memory accesses are to be counted. - :arg numpy_types: A :class:`bool` specifying whether the types in the - returned mapping should be numpy types instead of - :class:`loopy.LoopyType`. - :arg count_redundant_work: Based on usage of hardware axes or other specifics, a kernel may perform work redundantly. This :class:`bool` flag indicates whether this work should be included in the count. 
@@ -1790,62 +1931,46 @@ def get_mem_access_map(program, numpy_types=True, count_redundant_work=False, """ from loopy.preprocess import preprocess_program, infer_unknown_types - program = infer_unknown_types(program, expect_completion=True) program = preprocess_program(program) + # Ordering restriction: preprocess might insert arguments to + # make strides valid. Those also need to go through type inference. + program = infer_unknown_types(program, expect_completion=True) - access_map = ToCountMap() - - callables_count = program.callables_table.callables_count - - for func_id, in_knl_callable in program.callables_table.items(): - if isinstance(in_knl_callable, CallableKernel): - knl = in_knl_callable.subkernel - knl_access_map = get_mem_access_map_for_single_kernel(knl, - program.callables_table, numpy_types, - count_redundant_work, subgroup_size) - - # FIXME: didn't see any easy way to multiply - for i in range(callables_count[func_id]): - access_map += knl_access_map - elif isinstance(in_knl_callable, ScalarCallable): - pass - else: - raise NotImplementedError("Unknown callabke types %s." % ( - type(in_knl_callable).__name__)) + if numpy_types is not None: + from warnings import warn + warn("numpy_types is being ignored and will be removed in 2020.", + DeprecationWarning, stacklevel=2) - return access_map + return _get_mem_access_map_for_single_kernel( + program[program.name], program.callables_table, + count_redundant_work=count_redundant_work, + subgroup_size=subgroup_size) # }}} # {{{ get_synchronization_map -def get_synchronization_map_for_single_kernel(knl, callables_table, +def _get_synchronization_map_for_single_kernel(knl, callables_table, subgroup_size=None): if not knl.options.ignore_boostable_into: raise LoopyError("Kernel '%s': Using operation counting requires the option " "ignore_boostable_into to be set." 
% knl.name) + knl = lp.get_one_scheduled_kernel(knl, callables_table) + from loopy.schedule import (EnterLoop, LeaveLoop, Barrier, CallKernel, ReturnFromKernel, RunInstruction) - from operator import mul - knl = lp.get_one_scheduled_kernel(knl, callables_table) - iname_list = [] - result = ToCountMap() + kernel_rec = partial(_get_synchronization_map_for_single_kernel, + callables_table=callables_table, + subgroup_size=subgroup_size) - one = isl.PwQPolynomial('{ 1 }') + sync_counter = CounterBase(knl, callables_table, kernel_rec) + sync_map = sync_counter.new_zero_poly_map() - def get_count_poly(iname_list): - if iname_list: - ct = (count(knl, ( - knl.get_inames_domain(iname_list). - project_out_except(iname_list, [dim_type.set]) - )), ) - return reduce(mul, ct) - else: - return one + iname_list = [] for sched_item in knl.schedule: if isinstance(sched_item, EnterLoop): @@ -1856,22 +1981,27 @@ def get_synchronization_map_for_single_kernel(knl, callables_table, iname_list.pop() elif isinstance(sched_item, Barrier): - result = result + ToCountMap({"barrier_%s" % - sched_item.synchronization_kind: - get_count_poly(iname_list)}) + sync_map = sync_map + ToCountMap( + {Sync( + "barrier_%s" % sched_item.synchronization_kind, + knl.name): count_inames_domain(knl, frozenset(iname_list))}) + + elif isinstance(sched_item, RunInstruction): + pass elif isinstance(sched_item, CallKernel): - result = result + ToCountMap( - {"kernel_launch": get_count_poly(iname_list)}) + sync_map = sync_map + ToCountMap( + {Sync("kernel_launch", knl.name): + count_inames_domain(knl, frozenset(iname_list))}) - elif isinstance(sched_item, (ReturnFromKernel, RunInstruction)): + elif isinstance(sched_item, ReturnFromKernel): pass else: raise LoopyError("unexpected schedule item: %s" % type(sched_item).__name__) - return result + return sync_map def get_synchronization_map(program, subgroup_size=None): @@ -1913,45 +2043,21 @@ def get_synchronization_map(program, subgroup_size=None): from 
loopy.preprocess import preprocess_program, infer_unknown_types - program = infer_unknown_types(program, expect_completion=True) program = preprocess_program(program) + # Ordering restriction: preprocess might insert arguments to + # make strides valid. Those also need to go through type inference. + program = infer_unknown_types(program, expect_completion=True) - sync_map = ToCountMap() - callables_count = program.callables_table.callables_count - - for func_id, in_knl_callable in program.callables_table.items(): - if isinstance(in_knl_callable, CallableKernel): - knl = in_knl_callable.subkernel - knl_sync_map = get_synchronization_map_for_single_kernel(knl, - program.callables_table, subgroup_size) - - # FIXME: didn't see any easy way to multiply - for i in range(callables_count[func_id]): - sync_map += knl_sync_map - elif isinstance(in_knl_callable, ScalarCallable): - pass - else: - raise NotImplementedError("Unknown callabke types %s." % ( - type(in_knl_callable).__name__)) - - return sync_map + return _get_synchronization_map_for_single_kernel( + program[program.name], program.callables_table, + subgroup_size=subgroup_size) # }}} # {{{ gather_access_footprints -def gather_access_footprints_for_single_kernel(kernel, ignore_uncountable=False): - """Return a dictionary mapping ``(var_name, direction)`` to - :class:`islpy.Set` instances capturing which indices of each the array - *var_name* are read/written (where *direction* is either ``read`` or - ``write``. - - :arg ignore_uncountable: If *False*, an error will be raised for accesses - on which the footprint cannot be determined (e.g. 
data-dependent or - nonlinear indices) - """ - +def _gather_access_footprints_for_single_kernel(kernel, ignore_uncountable): write_footprints = [] read_footprints = [] @@ -1978,6 +2084,16 @@ def gather_access_footprints_for_single_kernel(kernel, ignore_uncountable=False) def gather_access_footprints(program, ignore_uncountable=False): + """Return a dictionary mapping ``(var_name, direction)`` to + :class:`islpy.Set` instances capturing which indices of each the array + *var_name* are read/written (where *direction* is either ``read`` or + ``write``. + + :arg ignore_uncountable: If *False*, an error will be raised for accesses + on which the footprint cannot be determined (e.g. data-dependent or + nonlinear indices) + """ + # FIMXE: works only for one callable kernel till now. if len([in_knl_callable for in_knl_callable in program.callables_table.values() if isinstance(in_knl_callable, @@ -1987,31 +2103,16 @@ def gather_access_footprints(program, ignore_uncountable=False): from loopy.preprocess import preprocess_program, infer_unknown_types - program = infer_unknown_types(program, expect_completion=True) program = preprocess_program(program) + # Ordering restriction: preprocess might insert arguments to + # make strides valid. Those also need to go through type inference. 
+ program = infer_unknown_types(program, expect_completion=True) write_footprints = [] read_footprints = [] - callables_count = program.callables_table.callables_count - - for func_id, in_knl_callable in program.callables_table.items(): - if isinstance(in_knl_callable, CallableKernel): - knl = in_knl_callable.subkernel - knl_write_footprints, knl_read_footprints = ( - gather_access_footprints_for_single_kernel(knl, - ignore_uncountable)) - - # FIXME: didn't see any easy way to multiply - for i in range(callables_count[func_id]): - write_footprints.extend(knl_write_footprints) - read_footprints.extend(knl_read_footprints) - - elif isinstance(in_knl_callable, ScalarCallable): - pass - else: - raise NotImplementedError("Unknown callabke types %s." % ( - type(in_knl_callable).__name__)) + write_footprints, read_footprints = _gather_access_footprints_for_single_kernel( + program[program.name], ignore_uncountable) write_footprints = AccessFootprintGatherer.combine(write_footprints) read_footprints = AccessFootprintGatherer.combine(read_footprints) diff --git a/test/test_statistics.py b/test/test_statistics.py index 41a88b386..cadca9fc1 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -218,16 +218,25 @@ def test_op_counter_bitwise(): m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - i32add = op_map[lp.Op(np.int32, 'add', CG.SUBGROUP)].eval_with_dict(params) - i32bw = op_map[lp.Op(np.int32, 'bw', CG.SUBGROUP)].eval_with_dict(params) - i64bw = op_map[lp.Op(np.dtype(np.int64), 'bw', CG.SUBGROUP) - ].eval_with_dict(params) - i64mul = op_map[lp.Op(np.dtype(np.int64), 'mul', CG.SUBGROUP) - ].eval_with_dict(params) - i64add = op_map[lp.Op(np.dtype(np.int64), 'add', CG.SUBGROUP) - ].eval_with_dict(params) - i64shift = op_map[lp.Op(np.dtype(np.int64), 'shift', CG.SUBGROUP) - ].eval_with_dict(params) + print(op_map) + i32add = op_map[ + lp.Op(np.int32, 'add', CG.SUBGROUP, 'bitwise') + ].eval_with_dict(params) + i32bw = op_map[ + lp.Op(np.int32, 'bw', 
CG.SUBGROUP, 'bitwise') + ].eval_with_dict(params) + i64bw = op_map[ + lp.Op(np.dtype(np.int64), 'bw', CG.SUBGROUP, 'bitwise') + ].eval_with_dict(params) + i64mul = op_map[ + lp.Op(np.dtype(np.int64), 'mul', CG.SUBGROUP, 'bitwise') + ].eval_with_dict(params) + i64add = op_map[ + lp.Op(np.dtype(np.int64), 'add', CG.SUBGROUP, 'bitwise') + ].eval_with_dict(params) + i64shift = op_map[ + lp.Op(np.dtype(np.int64), 'shift', CG.SUBGROUP, 'bitwise') + ].eval_with_dict(params) # (count-per-sub-group)*n_subgroups assert i32add == n*m*ell*n_subgroups assert i32bw == 2*n*m*ell*n_subgroups @@ -922,11 +931,10 @@ def test_barrier_counter_nobarriers(): ell = 128 params = {'n': n, 'm': m, 'ell': ell} assert len(sync_map) == 1 - assert sync_map["kernel_launch"].eval_with_dict(params) == 1 + assert sync_map.filter_by(kind="kernel_launch").eval_and_sum(params) == 1 def test_barrier_counter_barriers(): - knl = lp.make_kernel( "[n,m,ell] -> {[i,k,j]: 0<=i<50 and 1<=k<98 and 0<=j<10}", [ @@ -948,10 +956,25 @@ def test_barrier_counter_barriers(): m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - barrier_count = sync_map["barrier_local"].eval_with_dict(params) + barrier_count = sync_map.filter_by(kind="barrier_local").eval_and_sum(params) assert barrier_count == 50*10*2 +def test_barrier_count_single(): + knl = lp.make_kernel( + "{[i]: 0<=i<128}", + """ + <> c[i] = 15*i {id=yoink} + c[i+1] = c[i] {dep=yoink} + """) + + knl = lp.tag_inames(knl, {"i": "l.0"}) + sync_map = lp.get_synchronization_map(knl) + print(sync_map) + barrier_count = sync_map.filter_by(kind="barrier_local").eval_and_sum() + assert barrier_count == 1 + + def test_all_counters_parallel_matmul(): bsize = 16 knl = lp.make_kernel( @@ -978,8 +1001,8 @@ def test_all_counters_parallel_matmul(): sync_map = lp.get_synchronization_map(knl) assert len(sync_map) == 2 - assert sync_map["kernel_launch"].eval_with_dict(params) == 1 - assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/bsize + assert 
sync_map.filter_by(kind="kernel_launch").eval_and_sum(params) == 1 + assert sync_map.filter_by(kind="barrier_local").eval_and_sum(params) == 2*m/bsize op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True) f32mul = op_map[ @@ -1096,9 +1119,8 @@ def test_floor_div_coefficient_collector(): n_subgroups = n_workgroups*subgroups_per_group # count local f32 accesses - f32_local = lp.get_mem_access_map( - knl, count_redundant_work=True, subgroup_size=SGS - ).filter_by(dtype=[np.float32], mtype=["local"]).eval_and_sum(params) + m = lp.get_mem_access_map(knl, count_redundant_work=True, subgroup_size=SGS) + f32_local = m.filter_by(dtype=[np.float32], mtype=["local"]).eval_and_sum(params) # (count-per-sub-group)*n_subgroups assert f32_local == 2*(rept+1)*n_subgroups @@ -1176,7 +1198,7 @@ def test_gather_access_footprint(): fp = gather_access_footprints(knl) for key, footprint in six.iteritems(fp): - print(key, count(knl, footprint)) + print(key, count(knl.root_kernel, footprint)) def test_gather_access_footprint_2(): @@ -1191,8 +1213,8 @@ def test_gather_access_footprint_2(): params = {"n": 200} for key, footprint in six.iteritems(fp): - assert count(knl, footprint).eval_with_dict(params) == 200 - print(key, count(knl, footprint)) + assert count(knl.root_kernel, footprint).eval_with_dict(params) == 200 + print(key, count(knl.root_kernel, footprint)) def test_summations_and_filters(): @@ -1316,8 +1338,8 @@ def test_strided_footprint(): x_l_foot = footprints[('x', 'read')] from loopy.statistics import count - num = count(knl, x_l_foot).eval_with_dict(param_dict) - denom = count(knl, x_l_foot.remove_divs()).eval_with_dict(param_dict) + num = count(knl.root_kernel, x_l_foot).eval_with_dict(param_dict) + denom = count(knl.root_kernel, x_l_foot.remove_divs()).eval_with_dict(param_dict) assert 2*num < denom -- GitLab