diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index b97639c910f26b1390314c7364643de87ef9889d..69767d5e689543ab5e2f6641c6697e457d7a0b2b 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -27,12 +27,14 @@ THE SOFTWARE. import numpy as np from pymbolic.mapper import CSECachingMapperMixin +from pymbolic.primitives import Slice, Variable, Subscript from loopy.tools import intern_frozenset_of_ids -from loopy.symbolic import IdentityMapper, WalkMapper, CombineMapper +from loopy.symbolic import IdentityMapper, WalkMapper, CombineMapper, SubArrayRef from loopy.kernel.data import ( InstructionBase, MultiAssignmentBase, Assignment, SubstitutionRule) +from loopy.kernel.instruction import CInstruction, _DataObliviousInstruction from loopy.diagnostic import LoopyError, warn_with_kernel import islpy as isl from islpy import dim_type @@ -498,7 +500,7 @@ def parse_insn(groups, insn_options): if isinstance(inner_lhs_i, Lookup): inner_lhs_i = inner_lhs_i.aggregate - from loopy.symbolic import LinearSubscript, SubArrayRef + from loopy.symbolic import LinearSubscript if isinstance(inner_lhs_i, Variable): assignee_names.append(inner_lhs_i.name) elif isinstance(inner_lhs_i, (Subscript, LinearSubscript)): @@ -2001,6 +2003,119 @@ def scope_functions(kernel): # }}} +# {{{ slice to sub array ref + +def get_slice_params(expr, domain_length): + """ + Either reads the params from the slice or initiates the value to defaults. + """ + start, stop, step = expr.start, expr.stop, expr.step + + if start is None: + start = 0 + + if stop is None: + stop = domain_length + + if step is None: + step = 1 + + return start, stop, step + + +class SliceToInameReplacer(IdentityMapper): + """ + Mapper that converts slices to instances of :class:`SubArrayRef`. + """ + def __init__(self, knl, var_name_gen): + self.var_name_gen = var_name_gen + self.knl = knl + self.iname_domains = {} + + def map_subscript(self, expr): + updated_index = [] + swept_inames = [] + for i, index in enumerate(expr.index_tuple): + if isinstance(index, Slice): + unique_var_name = self.var_name_gen(based_on="islice") + if expr.aggregate.name in self.knl.arg_dict: + domain_length = self.knl.arg_dict[expr.aggregate.name].shape[i] + elif expr.aggregate.name in self.knl.temporary_variables: + domain_length = self.knl.temporary_variables[ + expr.aggregate.name].shape[i] + else: + raise LoopyError("Slice notation is only supported for " + "variables whose shapes are known at creation time " + "-- maybe add the shape for the sliced argument.") + start, stop, step = get_slice_params( + index, domain_length) + self.iname_domains[unique_var_name] = (start, stop, step) + + updated_index.append(step*Variable(unique_var_name)) + swept_inames.append(Variable(unique_var_name)) + else: + updated_index.append(index) + + if swept_inames: + return SubArrayRef(tuple(swept_inames), Subscript( + self.rec(expr.aggregate), + self.rec(tuple(updated_index)))) + else: + return IdentityMapper.map_subscript(self, expr) + + def get_iname_domain_as_isl_set(self): + """ + Returns the extra domain constraints imposed by the slice inames. + """ + if not self.iname_domains: + return None + + ctx = self.knl.isl_context + space = isl.Space.create_from_names(ctx, + set=list(self.iname_domains.keys())) + iname_set = isl.BasicSet.universe(space) + + for iname, (start, stop, step) in self.iname_domains.items(): + iname_set = (iname_set + .add_constraint(isl.Constraint.ineq_from_names(space, {1: + -start, iname: step})) + .add_constraint(isl.Constraint.ineq_from_names(space, {1: + stop-1, iname: -step}))) + + return iname_set + + +def realize_slices_as_sub_array_refs(kernel): + """ + Transformation that returns a kernel with the instances of + :class:`pymbolic.primitives.Slice` to `loopy.symbolic.SubArrayRef` + """ + unique_var_name_generator = kernel.get_var_name_generator() + slice_replacer = SliceToInameReplacer(kernel, unique_var_name_generator) + new_insns = [] + + for insn in kernel.instructions: + if isinstance(insn, (MultiAssignmentBase, CInstruction)): + new_expr = slice_replacer(insn.expression) + new_insns.append(insn.copy(expression=new_expr)) + elif isinstance(insn, _DataObliviousInstruction): + new_insns.append(insn) + else: + raise NotImplementedError("parse_slices not implemented for %s" % + type(insn)) + + slice_iname_domains = slice_replacer.get_iname_domain_as_isl_set() + + if slice_iname_domains: + d1, d2 = isl.align_two(kernel.domains[0], slice_iname_domains) + return kernel.copy(domains=[d1 & d2], + instructions=new_insns) + else: + return kernel.copy(instructions=new_insns) + +# }}} + + # {{{ kernel creation top-level def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): @@ -2298,6 +2413,10 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): check_for_nonexistent_iname_deps(knl) knl = create_temporaries(knl, default_order) + + # Convert slices to iname domains + knl = realize_slices_as_sub_array_refs(knl) + # ------------------------------------------------------------------------- # Ordering dependency: # -------------------------------------------------------------------------