diff --git a/loopy/__init__.py b/loopy/__init__.py index 66ba75024f1cbd80349025991af6927346db88b6..b32ed2e0b04ad617e7338d3d24decc7562e1daf3 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -73,7 +73,7 @@ from loopy.transform.iname import ( affine_map_inames, find_unused_axis_tag, make_reduction_inames_unique, has_schedulable_iname_nesting, get_iname_duplication_options, - add_inames_to_insn, add_inames_for_unused_hw_axes) + add_inames_to_insn, add_inames_for_unused_hw_axes, map_domain) from loopy.transform.instruction import ( find_instructions, map_instructions, @@ -190,7 +190,7 @@ __all__ = [ "affine_map_inames", "find_unused_axis_tag", "make_reduction_inames_unique", "has_schedulable_iname_nesting", "get_iname_duplication_options", - "add_inames_to_insn", "add_inames_for_unused_hw_axes", + "add_inames_to_insn", "add_inames_for_unused_hw_axes", "map_domain", "add_prefetch", "change_arg_to_image", "tag_array_axes", "tag_data_axes", diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index fb5e8d781ebc3f8c806dfa7b531560f0855c98d5..d832adbd7543b97d8ecb3eb3c133a8e640d78bda 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -1783,6 +1783,265 @@ def add_inames_to_insn(kernel, inames, insn_match): # }}} +# {{{ map_domain + +class _MapDomainMapper(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, within, new_inames, substitutions): + super(_MapDomainMapper, self).__init__(rule_mapping_context) + + self.within = within + + self.old_inames = frozenset(substitutions) + self.new_inames = new_inames + + self.substitutions = substitutions + + def map_reduction(self, expr, expn_state): + red_overlap = frozenset(expr.inames) & self.old_inames + arg_ctx_overlap = frozenset(expn_state.arg_context) & self.old_inames + if (red_overlap + and self.within( + expn_state.kernel, + expn_state.instruction)): + if len(red_overlap) != len(self.old_inames): + raise LoopyError("reduction '%s' involves a part " + "of the map domain inames. Reductions must " + "either involve all or none of the map domain " + "inames." % str(expr)) + + if arg_ctx_overlap: + if arg_ctx_overlap == red_overlap: + # All variables are shadowed by context, that's OK. + return super(_MapDomainMapper, self).map_reduction( + expr, expn_state) + else: + raise LoopyError("reduction '%s' has" + "some of the reduction variables affected " + "by the map_domain shadowed by context. " + "Either all or none must be shadowed." + % str(expr)) + + new_inames = list(expr.inames) + for old_iname in self.old_inames: + new_inames.remove(old_iname) + new_inames.extend(self.new_inames) + + from loopy.symbolic import Reduction + return Reduction(expr.operation, tuple(new_inames), + self.rec(expr.expr, expn_state), + expr.allow_simultaneous) + else: + return super(_MapDomainMapper, self).map_reduction(expr, expn_state) + + def map_variable(self, expr, expn_state): + if (expr.name in self.old_inames + and expr.name not in expn_state.arg_context + and self.within( + expn_state.kernel, + expn_state.instruction)): + return self.substitutions[expr.name] + else: + return super(_MapDomainMapper, self).map_variable(expr, expn_state) + + +def _find_aff_subst_from_map(iname, isl_map): + if not isinstance(isl_map, isl.BasicMap): + raise RuntimeError("isl_map must be a BasicMap") + + dt, dim_idx = isl_map.get_var_dict()[iname] + + assert dt == dim_type.in_ + + # Force isl to solve for only this iname on its side of the map, by + # projecting out all other "in" variables. + isl_map = isl_map.project_out(dt, dim_idx+1, isl_map.dim(dt)-(dim_idx+1)) + isl_map = isl_map.project_out(dt, 0, dim_idx) + dim_idx = 0 + + # Convert map to set to avoid "domain of affine expression should be a set". + # The old "in" variable will be the last of the out_dims. + new_dim_idx = isl_map.dim(dim_type.out) + isl_map = isl_map.move_dims( + dim_type.out, isl_map.dim(dim_type.out), + dt, dim_idx, 1) + isl_map = isl_map.range() # now a set + dt = dim_type.set + dim_idx = new_dim_idx + del new_dim_idx + + for cns in isl_map.get_constraints(): + if cns.is_equality() and cns.involves_dims(dt, dim_idx, 1): + coeff = cns.get_coefficient_val(dt, dim_idx) + cns_zeroed = cns.set_coefficient_val(dt, dim_idx, 0) + if cns_zeroed.involves_dims(dt, dim_idx, 1): + # not suitable, constraint still involves dim, perhaps in a div + continue + + if coeff.is_one(): + return -cns_zeroed.get_aff() + elif coeff.is_negone(): + return cns_zeroed.get_aff() + else: + # not suitable, coefficient does not have unit coefficient + continue + + raise LoopyError("no suitable equation for '%s' found" % iname) + + +def map_domain(kernel, isl_map, within=None): + # FIXME: Express _split_iname_backend in terms of this + # Missing/deleted for now: + # - slab processing + # - priorities processing + # FIXME: Process priorities + # FIXME: Express affine_map_inames in terms of this, deprecate + # FIXME: Document + + # FIXME: Support within + + # {{{ within processing (disabled for now) + if within is not None: + raise NotImplementedError("within") + + from loopy.match import parse_match + within = parse_match(within) + + # {{{ return the same kernel if no kernel matches + + def _do_not_transform_if_no_within_matches(): + for insn in kernel.instructions: + if within(kernel, insn): + return + + return kernel + + _do_not_transform_if_no_within_matches() + + # }}} + + # }}} + + if not isl_map.is_bijective(): + raise LoopyError("isl_map must be bijective") + + new_inames = frozenset(isl_map.get_var_dict(dim_type.out)) + old_inames = frozenset(isl_map.get_var_dict(dim_type.in_)) + + # {{{ solve for representation of old inames in terms of new + + substitutions = {} + var_substitutions = {} + applied_iname_rewrites = kernel.applied_iname_rewrites[:] + + from loopy.symbolic import aff_to_expr + from pymbolic import var + for iname in old_inames: + substitutions[iname] = aff_to_expr( + _find_aff_subst_from_map(iname, isl_map)) + var_substitutions[var(iname)] = aff_to_expr( + _find_aff_subst_from_map(iname, isl_map)) + + applied_iname_rewrites.append(var_substitutions) + del var_substitutions + + # }}} + + def process_set(s): + var_dict = s.get_var_dict() + + overlap = old_inames & frozenset(var_dict) + if overlap and len(overlap) != len(old_inames): + raise LoopyError("loop domain '%s' involves a part " + "of the map domain inames. Domains must " + "either involve all or none of the map domain " + "inames." % s) + + # {{{ align dims of isl_map and s + + # FIXME: Make this less gross + # FIXME: Make an exported/documented interface of this in islpy + from islpy import _align_dim_type + + map_with_s_domain = isl.Map.from_domain(s) + + dim_types = [dim_type.param, dim_type.in_, dim_type.out] + s_names = [ + map_with_s_domain.get_dim_name(dt, i) + for dt in dim_types + for i in range(map_with_s_domain.dim(dt)) + ] + map_names = [ + isl_map.get_dim_name(dt, i) + for dt in dim_types + for i in range(isl_map.dim(dt)) + ] + aligned_map = _align_dim_type( + dim_type.param, + isl_map, map_with_s_domain, False, + map_names, s_names) + aligned_map = _align_dim_type( + dim_type.in_, + isl_map, map_with_s_domain, False, + map_names, s_names) + # Old code + """ + aligned_map = _align_dim_type( + dim_type.param, + isl_map, map_with_s_domain, obj_bigger_ok=False, + obj_names=map_names, tgt_names=s_names) + aligned_map = _align_dim_type( + dim_type.in_, + isl_map, map_with_s_domain, obj_bigger_ok=False, + obj_names=map_names, tgt_names=s_names) + """ + # }}} + + return aligned_map.intersect_domain(s).range() + + # FIXME: Revive _project_out_only_if_all_instructions_in_within + + new_domains = [process_set(dom) for dom in kernel.domains] + + # {{{ update within_inames + + new_insns = [] + for insn in kernel.instructions: + overlap = old_inames & insn.within_inames + if overlap and within(kernel, insn): + if len(overlap) != len(old_inames): + raise LoopyError("instruction '%s' is within only a part " + "of the map domain inames. Instructions must " + "either be within all or none of the map domain " + "inames." % insn.id) + + insn = insn.copy( + within_inames=(insn.within_inames - old_inames) | new_inames) + else: + # leave insn unmodified + pass + + new_insns.append(insn) + + # }}} + + kernel = kernel.copy( + domains=new_domains, + instructions=new_insns, + applied_iname_rewrites=applied_iname_rewrites) + + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + ins = _MapDomainMapper(rule_mapping_context, within, + new_inames, substitutions) + + kernel = ins.map_kernel(kernel) + kernel = rule_mapping_context.finish_kernel(kernel) + + return kernel + +# }}} + + def add_inames_for_unused_hw_axes(kernel, within=None): """ Returns a kernel with inames added to each instruction diff --git a/test/test_transform.py b/test/test_transform.py index daa659808d1e7aa12f51d7b4b897672aa3344874..4f0b11ddb5d382fe84a50eb1ba395d5c54fc6585 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -550,7 +550,6 @@ def test_split_iname_only_if_in_within(): def test_nested_substs_in_insns(ctx_factory): ctx = ctx_factory() - import loopy as lp ref_knl = lp.make_kernel( "{[i]: 0<=i<10}", @@ -568,6 +567,59 @@ def test_nested_substs_in_insns(ctx_factory): lp.auto_test_vs_ref(ref_knl, ctx, knl) +def test_diamond_tiling(ctx_factory, interactive=False): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + ref_knl = lp.make_kernel( + "[nx,nt] -> {[ix, it]: 1<=ix {[ix, it] -> [tx, tt, tparity, itt, itx]: " + "16*(tx - tt) + itx - itt = ix - it and " + "16*(tx + tt + tparity) + itt + itx = ix + it and " + "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}") + knl = lp.map_domain(knl_for_transform, m) + knl = lp.prioritize_loops(knl, "tt,tparity,tx,itt,itx") + + if interactive: + nx = 43 + u = np.zeros((nx, 200)) + x = np.linspace(-1, 1, nx) + dx = x[1] - x[0] + u[:, 0] = u[:, 1] = np.exp(-100*x**2) + + u_dev = cl.array.to_device(queue, u) + knl(queue, u=u_dev, dx=dx, dt=dx) + + u = u_dev.get() + import matplotlib.pyplot as plt + plt.imshow(u.T) + plt.show() + else: + types = {"dt,dx,u": np.float64} + knl = lp.add_and_infer_dtypes(knl, types) + ref_knl = lp.add_and_infer_dtypes(ref_knl, types) + + lp.auto_test_vs_ref(ref_knl, ctx, knl, + parameters={ + "nx": 200, "nt": 300, + "dx": 1, "dt": 1 + }) + + def test_extract_subst_with_iname_deps_in_templ(ctx_factory): knl = lp.make_kernel( "{[i, j, k]: 0<=i<100 and 0<=j,k<5}",