diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index 0d15b9b4e550386090b1e7d54ebb167767609433..cf6e92771c2776ed98add12fb165eb0741624b36 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -123,6 +123,81 @@ class ArrayArgDescriptor(ImmutableRecord): update_persistent_hash = update_persistent_hash + +def get_arg_descriptor_for_expression(kernel, expr): + """ + :returns: a :class:`ArrayArgDescriptor` or a :class:`ValueArgDescriptor` + describing the argument expression *expr* which occurs + in a call in the code of *kernel*. + """ + from pymbolic.primitives import Variable + from loopy.symbolic import (SubArrayRef, pw_aff_to_expr, + SweptInameStrideCollector) + from loopy.kernel.data import TemporaryVariable, ArrayArg + + if isinstance(expr, SubArrayRef): + name = expr.subscript.aggregate.name + arg = kernel.get_var_descriptor(name) + + if not isinstance(arg, (TemporaryVariable, ArrayArg)): + raise LoopyError("unsupported argument type " + "'%s' of '%s' in call statement" + % (type(arg).__name__, expr.name)) + + aspace = arg.address_space + + from loopy.kernel.array import FixedStrideArrayDimTag as DimTag + from loopy.symbolic import simplify_using_aff + sub_dim_tags = [] + sub_shape = [] + + # FIXME This blindly assumes that dim_tag has a stride and + # will not work for non-stride dim tags (e.g. vec or sep). + + # FIXME: This will almost always be nonlinear--when does this + # actually help? Maybe the + linearized_index = simplify_using_aff( + kernel, + sum( + dim_tag.stride*iname for dim_tag, iname in + zip(arg.dim_tags, expr.subscript.index_tuple))) + + strides_as_dict = SweptInameStrideCollector( + tuple(iname.name for iname in expr.swept_inames) + )(linearized_index) + sub_dim_tags = tuple( + DimTag(strides_as_dict[iname]) for iname in expr.swept_inames) + sub_shape = tuple( + pw_aff_to_expr( + kernel.get_iname_bounds(iname.name).upper_bound_pw_aff)+1 + for iname in expr.swept_inames) + if expr.swept_inames == (): + sub_shape = (1, ) + sub_dim_tags = (DimTag(1),) + + return ArrayArgDescriptor( + address_space=aspace, + dim_tags=sub_dim_tags, + shape=sub_shape) + + elif isinstance(expr, Variable): + arg = kernel.get_var_descriptor(expr.name) + + if isinstance(arg, (TemporaryVariable, ArrayArg)): + raise LoopyError("may not pass entire array " + "'%s' in call statement in kernel '%s'" + % (expr.name, kernel.name)) + + elif isinstance(arg, ValueArg): + return ValueArgDescriptor() + else: + raise LoopyError("unsupported argument type " + "'%s' of '%s' in call statement" + % (type(arg).__name__, expr.name)) + + else: + return ValueArgDescriptor() + # }}} @@ -601,6 +676,11 @@ class CallableKernel(InKernelCallable): # {{{ map the arg_descrs so that all the variables are from the callees # perspective + # FIXME: This is ill-formed, because par can be an expression, e.g. + # 2*i+2 or 2*(i+1). A key feature of expression is that structural + # equality and semantic equality are not the same, so even if the + # SubstitutionMapper allowed non-variables, it would have to solve the + # (considerable) problem of expression equivalence. substs = {} for arg, par in zip(self.subkernel.args, expr.parameters): if isinstance(arg, ValueArg): diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index 540c77b12cec90f8c161b21ce62d989f2d98289c..3ae6a240acd8cb7977798fbf7c686931f7355380 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -22,6 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import six from six.moves import intern from pytools import ImmutableRecord, memoize_method from loopy.diagnostic import LoopyError @@ -1144,6 +1145,22 @@ class CallInstruction(MultiAssignmentBase): result += "\n" + 10*" " + "if (%s)" % " && ".join(self.predicates) return result + def arg_id_to_val(self): + """:returns: a :class:`dict` mapping argument identifiers (non-negative numbers + for positional arguments, strings for keyword args, and negative numbers + for assignees) to their respective values + """ + + from pymbolic.primitives import CallWithKwargs + arg_id_to_val = dict(enumerate(self.expression.parameters)) + if isinstance(self.expression, CallWithKwargs): + for kw, val in six.iteritems(self.expression.kw_parameters): + arg_id_to_val[kw] = val + for i, arg in enumerate(self.assignees): + arg_id_to_val[-i-1] = arg + + return arg_id_to_val + @property def atomicity(self): # Function calls can impossibly be atomic, and even the result assignment diff --git a/loopy/preprocess.py b/loopy/preprocess.py index e70e6b6fe8310592e3779570db78bc4aac45309b..40d49869e9afcaf0751939c4c73a776389ae8479 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2169,48 +2169,35 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): def map_call(self, expr, expn_state, **kwargs): from pymbolic.primitives import Call, CallWithKwargs - from loopy.kernel.function_interface import ValueArgDescriptor - from loopy.symbolic import ResolvedFunction, SubArrayRef + from loopy.symbolic import ResolvedFunction if not isinstance(expr.function, ResolvedFunction): # ignore if the call is not to a ResolvedFunction return super(ArgDescrInferenceMapper, self).map_call(expr, expn_state) - if isinstance(expr, Call): - kw_parameters = {} - else: - assert isinstance(expr, CallWithKwargs) - kw_parameters = expr.kw_parameters - - # descriptors for the args and kwargs of the Call - arg_id_to_descr = dict((i, par.get_array_arg_descriptor(self.caller_kernel)) - if isinstance(par, SubArrayRef) else (i, ValueArgDescriptor()) - for i, par in tuple(enumerate(expr.parameters)) + - tuple(kw_parameters.items())) - - assignee_id_to_descr = {} + arg_id_to_val = dict(enumerate(expr.parameters)) + if isinstance(expr, CallWithKwargs): + arg_id_to_val.update(expr.kw_parameters) if 'assignees' in kwargs: # If supplied with assignees then this is a CallInstruction assignees = kwargs['assignees'] - assert isinstance(assignees, tuple) - for i, par in enumerate(assignees): - if isinstance(par, SubArrayRef): - assignee_id_to_descr[-i-1] = ( - par.get_array_arg_descriptor(self.caller_kernel)) - else: - assignee_id_to_descr[-i-1] = ValueArgDescriptor() - - # gathering all the descriptors - combined_arg_id_to_descr = arg_id_to_descr.copy() - combined_arg_id_to_descr.update(assignee_id_to_descr) + for i, arg in enumerate(assignees): + arg_id_to_val[-i-1] = arg + + from loopy.kernel.function_interface import get_arg_descriptor_for_expression + arg_id_to_descr = dict( + (arg_id, get_arg_descriptor_for_expression( + self.caller_kernel, arg)) + for arg_id, arg in six.iteritems(arg_id_to_val)) # specializing the function according to the parameter description in_knl_callable = self.callables_table[expr.function.name] new_in_knl_callable, self.callables_table = ( in_knl_callable.with_descrs( - combined_arg_id_to_descr, self.caller_kernel, + arg_id_to_descr, self.caller_kernel, self.callables_table, expr)) + self.callables_table, new_func_id = ( self.callables_table.with_callable( expr.function.function, @@ -2229,7 +2216,7 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): for child in expr.parameters), dict( (key, self.rec(val, expn_state)) - for key, val in six.iteritems(kw_parameters)) + for key, val in six.iteritems(expr.kw_parameters)) ) map_call_with_kwargs = map_call diff --git a/loopy/symbolic.py b/loopy/symbolic.py index f717a07729ac2d042b954b9066a6cad559a785c6..d98c3fdeaa5d833197b964e92c2bb0f63704a7c4 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -827,55 +827,6 @@ class SubArrayRef(p.Expression): return EvaluatorWithDeficientContext(swept_inames_to_zeros)( self.subscript) - def get_array_arg_descriptor(self, kernel): - """ - Returns the dim_tags, memory scope, shape informations of a - :class:`SubArrayRef` argument in the caller kernel packed into - :class:`ArrayArgDescriptor` for the instance of :class:`SubArrayRef` in - the given *kernel*. - """ - from loopy.kernel.function_interface import ArrayArgDescriptor - - name = self.subscript.aggregate.name - - if name in kernel.temporary_variables: - assert name not in kernel.arg_dict - arg = kernel.temporary_variables[name] - else: - assert name in kernel.arg_dict - arg = kernel.arg_dict[name] - - aspace = arg.address_space - - from loopy.kernel.array import FixedStrideArrayDimTag as DimTag - from loopy.isl_helpers import simplify_via_aff - sub_dim_tags = [] - sub_shape = [] - try: - linearized_index = simplify_via_aff( - sum(dim_tag.stride*iname for dim_tag, iname in - zip(arg.dim_tags, self.subscript.index_tuple))) - except isl.Error: - linearized_index = sum(dim_tag.stride*iname for dim_tag, iname in - zip(arg.dim_tags, self.subscript.index_tuple)) - - strides_as_dict = SweptInameStrideCollector(tuple(iname.name for iname in - self.swept_inames))(linearized_index) - sub_dim_tags = tuple( - DimTag(strides_as_dict[iname]) for iname in self.swept_inames) - sub_shape = tuple( - pw_aff_to_expr( - kernel.get_iname_bounds(iname.name).upper_bound_pw_aff)+1 - for iname in self.swept_inames) - if self.swept_inames == (): - sub_shape = (1, ) - sub_dim_tags = (DimTag(1),) - - return ArrayArgDescriptor( - address_space=aspace, - dim_tags=sub_dim_tags, - shape=sub_shape) - def __getinitargs__(self): return (self.swept_inames, self.subscript) @@ -1685,6 +1636,7 @@ def guarded_pwaff_from_expr(space, expr, vars_to_zero=None): # {{{ simplify using aff +# FIXME: redundant with simplify_via_aff def simplify_using_aff(kernel, expr): inames = get_dependencies(expr) & kernel.all_inames() diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 953ad56134edaabbd5451ae9ecb1a9c07b80c1cb..042990c77a486fcd294811fa4d15ac24c3050771 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -34,7 +34,7 @@ from loopy.kernel.instruction import (CallInstruction, MultiAssignmentBase, Assignment, CInstruction, _DataObliviousInstruction) from loopy.symbolic import IdentityMapper, SubstitutionMapper, CombineMapper from loopy.isl_helpers import simplify_via_aff -from loopy.kernel.function_interface import (get_kw_pos_association, +from loopy.kernel.function_interface import ( CallableKernel, ScalarCallable) from loopy.program import Program, ResolvedFunctionMarker from loopy.symbolic import SubArrayRef @@ -616,10 +616,10 @@ class DimChanger(IdentityMapper): def _match_caller_callee_argument_dimension_for_single_kernel( caller_knl, callee_knl): """ - Returns a copy of *caller_knl* with the instance of - :class:`loopy.kernel.function_interface.CallableKernel` addressed by - *callee_function_name* in the *caller_knl* aligned with the argument - dimesnsions required by *caller_knl*. + :returns: a copy of *caller_knl* with the instance of + :class:`loopy.kernel.function_interface.CallableKernel` addressed by + *callee_function_name* in the *caller_knl* aligned with the argument + dimensions required by *caller_knl*. """ for insn in caller_knl.instructions: if not isinstance(insn, CallInstruction) or ( @@ -628,14 +628,6 @@ def _match_caller_callee_argument_dimension_for_single_kernel( # Call to a callable kernel can only occur through a # CallInstruction. continue - # getting the caller->callee arg association - - parameters = insn.expression.parameters[:] - kw_parameters = {} - if isinstance(insn.expression, CallWithKwargs): - kw_parameters = insn.expression.kw_parameters - - assignees = insn.assignees def _shape_1_if_empty(shape): assert isinstance(shape, tuple) @@ -644,34 +636,18 @@ def _match_caller_callee_argument_dimension_for_single_kernel( else: return shape - parameter_shapes = [] - for par in parameters: - if isinstance(par, SubArrayRef): - parameter_shapes.append( - _shape_1_if_empty( - par.get_array_arg_descriptor(caller_knl).shape)) - else: - parameter_shapes.append((1, )) - - kw_to_pos, pos_to_kw = get_kw_pos_association(callee_knl) - for i in range(len(parameters), len(parameters)+len(kw_parameters)): - parameter_shapes.append(_shape_1_if_empty(kw_parameters[pos_to_kw[i]]) - .get_array_arg_descriptor(caller_knl).shape) - - # inserting the assignees at the required positions. - assignee_write_count = -1 - for i, arg in enumerate(callee_knl.args): - if arg.is_output_only: - assignee = assignees[-assignee_write_count-1] - parameter_shapes.insert(i, _shape_1_if_empty(assignee - .get_array_arg_descriptor(caller_knl).shape)) - assignee_write_count -= 1 - - callee_arg_to_desired_dim_tag = dict(zip([arg.name for arg in - callee_knl.args], parameter_shapes)) + from loopy.kernel.function_interface import ( + ArrayArgDescriptor, get_arg_descriptor_for_expression) + arg_id_to_shape = {} + for arg_id, arg in six.iteritems(insn.arg_id_to_val()): + arg_descr = get_arg_descriptor_for_expression(caller_knl, arg) + if isinstance(arg_descr, ArrayArgDescriptor): + arg_id_to_shape[arg_id] = _shape_1_if_empty(arg_descr) + dim_changer = DimChanger( callee_knl.arg_dict, - callee_arg_to_desired_dim_tag) + arg_id_to_shape) + new_callee_insns = [] for callee_insn in callee_knl.instructions: if isinstance(callee_insn, MultiAssignmentBase): @@ -686,7 +662,7 @@ def _match_caller_callee_argument_dimension_for_single_kernel( raise NotImplementedError("Unknown instruction %s." % type(insn)) - # subkernel with instructions adjusted according to the new dimensions. + # subkernel with instructions adjusted according to the new dimensions new_callee_knl = callee_knl.copy(instructions=new_callee_insns) return new_callee_knl