From eb18c8dbc0df14be2e123e1e27caf982fa80f5d7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 5 Nov 2012 02:47:41 -0500 Subject: [PATCH] Properly deal with term rewriting in the presence of subst rules. Most tests pass. --- MEMO | 19 +- loopy/__init__.py | 135 +++++++------- loopy/context_matching.py | 193 ++++++++++++++++++++ loopy/creation.py | 181 ++----------------- loopy/cse.py | 365 +++++++++++++++++++------------------- loopy/kernel.py | 129 +++++--------- loopy/preprocess.py | 22 ++- loopy/subst.py | 30 +--- loopy/symbolic.py | 356 ++++++++++++++++++++----------------- test/test_linalg.py | 1 - test/test_loopy.py | 56 +++++- 11 files changed, 794 insertions(+), 693 deletions(-) create mode 100644 loopy/context_matching.py diff --git a/MEMO b/MEMO index 5d69b66be..23865403a 100644 --- a/MEMO +++ b/MEMO @@ -41,15 +41,23 @@ Things to consider - Dependency on non-local global writes is ill-formed +- No substitution rules allowed on lhs of insns + To-do ^^^^^ -- Prohibit known variable names as subst rule arguments - - Expose iname-duplicate-and-rename as a primitive. - Kernel fusion +- ExpandingIdentityMapper + extract_subst -> needs WalkMapper + duplicate_inames + join_inames + padding + split_iname [DONE] + CSE [DONE] + - Data implementation tags TODO initial bringup: - implemented_arg_info @@ -62,11 +70,18 @@ To-do - vectorization - automatic copies - write_image() + - change_arg_to_image (test!) + +- Import SEM test - Make tests run on GPUs Fixes: +- applied_iname_rewrites tracking for prefetch footprints isn't bulletproof + old inames may still be around, so the rewrite may or may not have to be + applied. + - Group instructions by dependency/inames for scheduling, to increase sched. 
scalability diff --git a/loopy/__init__.py b/loopy/__init__.py index 2f369e749..e06ffdbe3 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -39,6 +39,8 @@ from islpy import dim_type from pytools import MovedFunctionDeprecationWrapper +from loopy.symbolic import ExpandingIdentityMapper + @@ -93,10 +95,46 @@ class infer_type: # {{{ split inames +class _InameSplitter(ExpandingIdentityMapper): + def __init__(self, kernel, within, + split_iname, outer_iname, inner_iname, replacement_index): + ExpandingIdentityMapper.__init__(self, + kernel.substitutions, kernel.get_var_name_generator()) + + self.within = within + + self.split_iname = split_iname + self.outer_iname = outer_iname + self.inner_iname = inner_iname + + self.replacement_index = replacement_index + + def map_reduction(self, expr, expn_state): + if self.split_iname in expr.inames and self.within(expn_state.stack): + new_inames = list(expr.inames) + new_inames.remove(self.split_iname) + new_inames.extend([self.outer_iname, self.inner_iname]) + + from loopy.symbolic import Reduction + return Reduction(expr.operation, tuple(new_inames), + self.rec(expr.expr, expn_state)) + else: + return ExpandingIdentityMapper.map_reduction(self, expr, expn_state) + + def map_variable(self, expr, expn_state): + if expr.name == self.split_iname and self.within(expn_state.stack): + return self.replacement_index + else: + return ExpandingIdentityMapper.map_variable(self, expr, expn_state) + def split_iname(kernel, split_iname, inner_length, outer_iname=None, inner_iname=None, outer_tag=None, inner_tag=None, - slabs=(0, 0), do_tagged_check=True): + slabs=(0, 0), do_tagged_check=True, + within=None): + + from loopy.context_matching import parse_stack_match + within = parse_stack_match(within) existing_tag = kernel.iname_to_tag.get(split_iname) from loopy.kernel import ForceSequentialTag @@ -154,21 +192,16 @@ def split_iname(kernel, split_iname, inner_length, outer = var(outer_iname) new_loop_index = inner + outer*inner_length 
+ subst_map = {var(split_iname): new_loop_index} + applied_iname_rewrites.append(subst_map) + # {{{ actually modify instructions - from loopy.symbolic import ReductionLoopSplitter + ins = _InameSplitter(kernel, within, + split_iname, outer_iname, inner_iname, new_loop_index) - rls = ReductionLoopSplitter(split_iname, outer_iname, inner_iname) new_insns = [] for insn in kernel.instructions: - subst_map = {var(split_iname): new_loop_index} - applied_iname_rewrites.append(subst_map) - - from loopy.symbolic import SubstitutionMapper - subst_mapper = SubstitutionMapper(subst_map.get) - - new_expr = subst_mapper(rls(insn.expression)) - if split_iname in insn.forced_iname_deps: new_forced_iname_deps = ( (insn.forced_iname_deps.copy() @@ -178,8 +211,8 @@ def split_iname(kernel, split_iname, inner_length, new_forced_iname_deps = insn.forced_iname_deps insn = insn.copy( - assignee=subst_mapper(insn.assignee), - expression=new_expr, + assignee=ins(insn.assignee, insn.id), + expression=ins(insn.expression, insn.id), forced_iname_deps=new_forced_iname_deps) new_insns.append(insn) @@ -188,10 +221,11 @@ def split_iname(kernel, split_iname, inner_length, iname_slab_increments = kernel.iname_slab_increments.copy() iname_slab_increments[outer_iname] = slabs + result = (kernel - .map_expressions(subst_mapper, exclude_instructions=True) .copy(domains=new_domains, iname_slab_increments=iname_slab_increments, + substitutions=ins.get_new_substitutions(), instructions=new_insns, applied_iname_rewrites=applied_iname_rewrites, )) @@ -382,9 +416,11 @@ def _add_kernel_axis(kernel, axis_name, start, stop, base_inames): return kernel.copy(domains=domch.get_domains_with(domain)) def _process_footprint_subscripts(kernel, rule_name, sweep_inames, - footprint_subscripts, arg, newly_created_vars): + footprint_subscripts, arg): """Track applied iname rewrites, deal with slice specifiers ':'.""" + name_gen = kernel.get_var_name_generator() + from pymbolic.primitives import Variable if 
footprint_subscripts is None: @@ -423,11 +459,9 @@ def _process_footprint_subscripts(kernel, rule_name, sweep_inames, raise NotImplementedError("add_prefetch only " "supports full slices") - axis_name = kernel.make_unique_var_name( - based_on="%s_fetch_axis_%d" % (arg.name, axis_nr), - extra_used_vars=newly_created_vars) + axis_name = name_gen( + based_on="%s_fetch_axis_%d" % (arg.name, axis_nr)) - newly_created_vars.add(axis_name) kernel = _add_kernel_axis(kernel, axis_name, 0, arg.shape[axis_nr], frozenset(sweep_inames) | fsub_dependencies) sweep_inames = sweep_inames + [axis_name] @@ -537,11 +571,11 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None, kernel, subst_use, sweep_inames, inames_to_be_removed = \ _process_footprint_subscripts( kernel, rule_name, sweep_inames, - footprint_subscripts, arg, newly_created_vars) + footprint_subscripts, arg) - new_kernel = precompute(kernel, subst_use, arg.dtype, sweep_inames, + new_kernel = precompute(kernel, subst_use, sweep_inames, new_storage_axis_names=dim_arg_names, - default_tag=default_tag) + default_tag=default_tag, dtype=arg.dtype) # {{{ remove inames that were temporarily added by slice sweeps @@ -571,49 +605,19 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None, # {{{ instruction processing -class _IdMatch(object): - def __init__(self, value): - self.value = value - -class _ExactIdMatch(_IdMatch): - def __call__(self, insn): - return insn.id == self.value - -class _ReIdMatch: - def __call__(self, insn): - return self.value.match(insn.id) is not None - -def _parse_insn_match(insn_match): - import re - colon_idx = insn_match.find(":") - if colon_idx == -1: - return _ExactIdMatch(insn_match) - - match_tp = insn_match[:colon_idx] - match_val = insn_match[colon_idx+1:] - - if match_tp == "glob": - from fnmatch import translate - return _ReIdMatch(re.compile(translate(match_val))) - elif match_tp == "re": - return _ReIdMatch(re.compile(match_val)) - else: - raise 
ValueError("match type '%s' not understood" % match_tp) - - - - def find_instructions(kernel, insn_match): - match = _parse_insn_match(insn_match) - return [insn for insn in kernel.instructions if match(insn)] + from loopy.context_matching import parse_id_match + match = parse_id_match(insn_match) + return [insn for insn in kernel.instructions if match(insn.id, None)] def map_instructions(kernel, insn_match, f): - match = _parse_insn_match(insn_match) + from loopy.context_matching import parse_id_match + match = parse_id_match(insn_match) new_insns = [] for insn in kernel.instructions: - if match(insn): + if match(insn.id, None): new_insns.append(f(insn)) else: new_insns.append(insn) @@ -623,8 +627,8 @@ def map_instructions(kernel, insn_match, f): def set_instruction_priority(kernel, insn_match, priority): """Set the priority of instructions matching *insn_match* to *priority*. - *insn_match* may be an instruction id, a regular expression prefixed by `re:`, - or a file-name-style glob prefixed by `glob:`. + *insn_match* may be any instruction id match understood by + :func:`loopy.context_matching.parse_id_match`. """ def set_prio(insn): return insn.copy(priority=priority) @@ -634,8 +638,8 @@ def add_dependency(kernel, insn_match, dependency): """Add the instruction dependency *dependency* to the instructions matched by *insn_match*. - *insn_match* may be an instruction id, a regular expression prefixed by `re:`, - or a file-name-style glob prefixed by `glob:`. + *insn_match* may be any instruction id match understood by + :func:`loopy.context_matching.parse_id_match`. 
""" def add_dep(insn): return insn.copy(insn_deps=insn.insn_deps + [dependency]) @@ -659,6 +663,13 @@ def change_arg_to_image(knl, name): # }}} +# {{{ duplicate inames + +def duplicate_inames(knl, inames): + pass + +# }}} + diff --git a/loopy/context_matching.py b/loopy/context_matching.py new file mode 100644 index 000000000..51cc8a5fd --- /dev/null +++ b/loopy/context_matching.py @@ -0,0 +1,193 @@ +"""Matching functionality for instruction ids and subsitution +rule invocations stacks.""" + +from __future__ import division + +__copyright__ = "Copyright (C) 2012 Andreas Kloeckner" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + + + + +# {{{ id match objects + +class AllMatch(object): + def __call__(self, identifier, tag): + return True + +class RegexIdentifierMatch(object): + def __init__(self, id_re, tag_re=None): + self.id_re = id_re + self.tag_re = tag_re + + def __call__(self, identifier, tag): + if self.tag_re is None: + return self.id_re.match(identifier) is not None + else: + if tag is None: + tag = "" + + return ( + self.id_re.match(identifier) is not None + and self.tag_re.match(tag) is not None) + +class AlternativeMatch(object): + def __init__(self, matches): + self.matches = matches + + def __call__(self, identifier, tag): + from pytools import any + return any( + mtch(identifier, tag) for mtch in self.matches) + +# }}} + +# {{{ single id match parsing + +def parse_id_match(id_matches): + """Syntax examples: + + my_insn + compute_* + fetch*$first + fetch*$first,store*$first + + Alternatively, a list of *(name_glob, tag_glob)* tuples. + """ + + if id_matches is None: + return AllMatch() + + if isinstance(id_matches, str): + id_matches = id_matches.split(",") + + if len(id_matches) > 1: + return AlternativeMatch(parse_id_match(im) for im in id_matches) + + if len(id_matches) == 0: + return AllMatch() + + id_match, = id_matches + del id_matches + + def re_from_glob(s): + import re + from fnmatch import translate + return re.compile(translate(s.strip())) + + if not isinstance(id_match, tuple): + components = id_match.split("$") + + if len(components) == 1: + return RegexIdentifierMatch(re_from_glob(components[0])) + elif len(components) == 2: + return RegexIdentifierMatch( + re_from_glob(components[0]), + re_from_glob(components[1])) + else: + raise RuntimeError("too many (%d) $-separated components in id match" + % len(components)) + +# }}} + +# {{{ stack match objects + +# these match from the tail of the stack + +class StackMatchBase(object): + pass + +class AllStackMatch(StackMatchBase): + def __call__(self, stack): + return True + +class StackIdMatch(StackMatchBase): + 
def __init__(self, id_match, up_match): + self.id_match = id_match + self.up_match = up_match + + def __call__(self, stack): + if not stack: + return False + + last = stack[-1] + if not self.id_match(*last): + return False + + if self.up_match is None: + return True + else: + return self.up_match(stack[:-1]) + +class StackWildcardMatch(StackMatchBase): + def __init__(self, up_match): + self.up_match = up_match + + def __call__(self, stack): + if self.up_match is None: + return True + + n = len(stack) + + for i in xrange(n): + if self.up_match(stack[:-i]): + return True + + return False + +# }}} + +# {{{ stack match parsing + +def parse_stack_match(smatch): + """Syntax example:: + + lowest < next < ... < highest + + where `lowest` is necessarily the bottom of the stack. There is currently + no way to anchor to the top of the stack. + """ + + if isinstance(smatch, StackMatchBase): + return smatch + + match = AllStackMatch() + + if smatch is None: + return match + + components = smatch.split("<") + + for comp in components[::-1]: + comp = comp.strip() + if comp == "...": + match = StackWildcardMatch(match) + else: + match = StackIdMatch(parse_id_match(comp), match) + + return match + +# }}} + + + +# vim: foldmethod=marker diff --git a/loopy/creation.py b/loopy/creation.py index 55bc29d44..5c16d24cd 100644 --- a/loopy/creation.py +++ b/loopy/creation.py @@ -247,172 +247,31 @@ def create_temporaries(knl): # }}} -# {{{ reduction iname duplication +# {{{ check for reduction iname duplication -def duplicate_reduction_inames(kernel): +def check_for_reduction_inames_duplication_requests(kernel): # {{{ helper function - newly_created_vars = set() - - def duplicate_reduction_inames(reduction_expr, rec): - child = rec(reduction_expr.expr) - new_red_inames = [] - did_something = False - + def check_reduction_inames(reduction_expr, rec): for iname in reduction_expr.inames: if iname.startswith("@"): - new_iname = kernel.make_unique_var_name(iname[1:]+"_"+name_base, - 
newly_created_vars) - - old_inames.append(iname.lstrip("@")) - new_inames.append(new_iname) - newly_created_vars.add(new_iname) - new_red_inames.append(new_iname) - did_something = True - else: - new_red_inames.append(iname) - - if did_something: - from loopy.symbolic import SubstitutionMapper - from pymbolic.mapper.substitutor import make_subst_func - from pymbolic import var - - subst_dict = dict( - (old_iname, var(new_iname)) - for old_iname, new_iname in zip( - reduction_expr.untagged_inames, new_red_inames)) - subst_map = SubstitutionMapper(make_subst_func(subst_dict)) - - child = subst_map(child) - - from loopy.symbolic import Reduction - return Reduction( - operation=reduction_expr.operation, - inames=tuple(new_red_inames), - expr=child) + raise RuntimeError("Reduction iname duplication with '@' is no " + "longer supported. Use loopy.duplicate_inames instead.") # }}} - from loopy.symbolic import ReductionCallbackMapper - from loopy.isl_helpers import duplicate_axes - - new_domains = kernel.domains - new_insns = [] - - new_iname_to_tag = kernel.iname_to_tag.copy() + from loopy.symbolic import ReductionCallbackMapper + rcm = ReductionCallbackMapper(check_reduction_inames) for insn in kernel.instructions: - old_inames = [] - new_inames = [] - name_base = insn.id - - new_insns.append(insn.copy( - expression=ReductionCallbackMapper(duplicate_reduction_inames) - (insn.expression))) - - for old, new in zip(old_inames, new_inames): - new_domains = duplicate_axes(new_domains, [old], [new]) - if old in kernel.iname_to_tag: - new_iname_to_tag[new] = kernel.iname_to_tag[old] + rcm(insn.expression) - new_substs = {} for sub_name, sub_rule in kernel.substitutions.iteritems(): - old_inames = [] - new_inames = [] - name_base = sub_name - - new_substs[sub_name] = sub_rule.copy( - expression=ReductionCallbackMapper(duplicate_reduction_inames) - (sub_rule.expression)) - - for old, new in zip(old_inames, new_inames): - new_domains = duplicate_axes(new_domains, [old], [new]) - 
if old in kernel.iname_to_tag: - new_iname_to_tag[new] = kernel.iname_to_tag[old] - - return kernel.copy( - instructions=new_insns, - substitutions=new_substs, - domains=new_domains, - iname_to_tag=new_iname_to_tag) + rcm(sub_rule.expression) # }}} -# {{{ duplicate inames - -def duplicate_inames(knl): - new_insns = [] - new_domains = knl.domains - new_iname_to_tag = knl.iname_to_tag.copy() - - newly_created_vars = set() - - for insn in knl.instructions: - if insn.duplicate_inames_and_tags: - insn_dup_iname_to_tag = dict(insn.duplicate_inames_and_tags) - - if not set(insn_dup_iname_to_tag.keys()) <= knl.all_inames(): - raise ValueError("In instruction '%s': " - "cannot duplicate inames '%s'--" - "they don't exist" % ( - insn.id, - ",".join( - set(insn_dup_iname_to_tag.keys())-knl.all_inames()))) - - # {{{ duplicate non-reduction inames - - reduction_inames = insn.reduction_inames() - - inames_to_duplicate = [iname - for iname, tag in insn.duplicate_inames_and_tags - if iname not in reduction_inames] - - new_inames = [ - knl.make_unique_var_name( - based_on=iname+"_"+insn.id, - extra_used_vars=newly_created_vars) - for iname in inames_to_duplicate] - - for old_iname, new_iname in zip(inames_to_duplicate, new_inames): - new_tag = insn_dup_iname_to_tag[old_iname] - new_iname_to_tag[new_iname] = new_tag - - newly_created_vars.update(new_inames) - - from loopy.isl_helpers import duplicate_axes - new_domains = duplicate_axes(new_domains, inames_to_duplicate, new_inames) - - from loopy.symbolic import SubstitutionMapper - from pymbolic.mapper.substitutor import make_subst_func - from pymbolic import var - old_to_new = dict( - (old_iname, var(new_iname)) - for old_iname, new_iname in zip(inames_to_duplicate, new_inames)) - subst_map = SubstitutionMapper(make_subst_func(old_to_new)) - new_expression = subst_map(insn.expression) - - # }}} - - if len(inames_to_duplicate) < len(insn.duplicate_inames_and_tags): - raise RuntimeError("cannot use [|...] 
syntax to rename reduction " - "inames") - - insn = insn.copy( - assignee=subst_map(insn.assignee), - expression=new_expression, - forced_iname_deps=set( - old_to_new.get(iname, iname) for iname in insn.forced_iname_deps), - duplicate_inames_and_tags=[]) - - new_insns.append(insn) - - return knl.copy( - instructions=new_insns, - domains=new_domains, - iname_to_tag=new_iname_to_tag) -# }}} - # {{{ kernel creation top-level def make_kernel(*args, **kwargs): @@ -430,30 +289,11 @@ def make_kernel(*args, **kwargs): iname_to_tag_requests=[]) check_for_nonexistent_iname_deps(knl) + check_for_reduction_inames_duplication_requests(knl) - knl = duplicate_reduction_inames(knl) - - # ------------------------------------------------------------------------- - # Ordering dependency: - # ------------------------------------------------------------------------- - # Must duplicate reduction inames before tagging reduction inames as - # sequential because otherwise the latter operation will run into @iname - # (i.e. duplication) markers and not understand them. - # ------------------------------------------------------------------------- knl = tag_reduction_inames_as_sequential(knl) - knl = create_temporaries(knl) - knl = duplicate_inames(knl) - - # ------------------------------------------------------------------------- - # Ordering dependency: - # ------------------------------------------------------------------------- - # Must duplicate inames before expanding CSEs, otherwise inames within the - # scope of duplication might be CSE'd out to a different instruction and - # never be found by duplication. - # ------------------------------------------------------------------------- - knl = expand_cses(knl) # ------------------------------------------------------------------------- @@ -462,6 +302,7 @@ def make_kernel(*args, **kwargs): # Must create temporary before checking for writes to temporary variables # that are domain parameters. 
# ------------------------------------------------------------------------- + check_for_multiple_writes_to_loop_bounds(knl) check_for_duplicate_names(knl) check_written_variable_names(knl) diff --git a/loopy/cse.py b/loopy/cse.py index e5edaa10c..48ef6e356 100644 --- a/loopy/cse.py +++ b/loopy/cse.py @@ -27,7 +27,8 @@ THE SOFTWARE. import islpy as isl from islpy import dim_type -from loopy.symbolic import get_dependencies, SubstitutionMapper +from loopy.symbolic import (get_dependencies, SubstitutionMapper, + ExpandingIdentityMapper) from pymbolic.mapper.substitutor import make_subst_func import numpy as np @@ -39,16 +40,13 @@ from pymbolic import var class InvocationDescriptor(Record): __slots__ = [ - "expr", "args", "expands_footprint", "is_in_footprint", - # Record from which substitution rule this invocation of the rule - # being precomputed originated. If all invocations end up being - # in-footprint, then the replacement with the prefetch can be made - # within the rule. - "from_subst_rule" + # Remember where the invocation happened, in terms of the expansion + # call stack. 
+ "expansion_stack", ] @@ -379,9 +377,161 @@ def simplify_via_aff(expr): -def precompute(kernel, subst_use, dtype, sweep_inames=[], +class InvocationGatherer(ExpandingIdentityMapper): + def __init__(self, kernel, subst_name, subst_tag, within): + ExpandingIdentityMapper.__init__(self, + kernel.substitutions, kernel.get_var_name_generator()) + + from loopy.symbolic import ParametrizedSubstitutor + self.subst_expander = ParametrizedSubstitutor( + kernel.substitutions) + + self.kernel = kernel + self.subst_name = subst_name + self.subst_tag = subst_tag + self.within = within + + self.invocation_descriptors = [] + + def map_substitution(self, name, tag, arguments, expn_state): + process_me = name == self.subst_name + + if self.subst_tag is not None and self.subst_tag != tag: + process_me = False + + process_me = process_me and self.within(expn_state.stack) + + if not process_me: + return ExpandingIdentityMapper.map_substitution( + self, name, tag, arguments, expn_state) + + rule = self.old_subst_rules[name] + arg_context = self.make_new_arg_context( + name, rule.arguments, arguments, expn_state.arg_context) + + arg_deps = set() + for arg_val in arg_context.itervalues(): + arg_deps = (arg_deps + | get_dependencies(self.subst_expander(arg_val, insn_id=None))) + + if not arg_deps <= self.kernel.all_inames(): + from warnings import warn + warn("Precompute arguments in '%s(%s)' do not consist exclusively " + "of inames and constants--specifically, these are " + "not inames: %s. Ignoring." 
% ( + name, + ", ".join(str(arg) for arg in arguments), + ", ".join(arg_deps - self.kernel.all_inames()), + )) + + return ExpandingIdentityMapper.map_substitution( + self, name, tag, arguments, expn_state) + + self.invocation_descriptors.append( + InvocationDescriptor( + args=[arg_context[arg_name] for arg_name in rule.arguments], + expansion_stack=expn_state.stack)) + + return 0 # exact value irrelevant + + + + +class InvocationReplacer(ExpandingIdentityMapper): + def __init__(self, kernel, subst_name, subst_tag, within, + invocation_descriptors, + storage_axis_names, storage_axis_sources, + storage_base_indices, non1_storage_axis_names, + target_var_name): + ExpandingIdentityMapper.__init__(self, + kernel.substitutions, kernel.get_var_name_generator()) + + from loopy.symbolic import ParametrizedSubstitutor + self.subst_expander = ParametrizedSubstitutor( + kernel.substitutions, kernel.get_var_name_generator()) + + self.kernel = kernel + self.subst_name = subst_name + self.subst_tag = subst_tag + self.within = within + + self.invocation_descriptors = invocation_descriptors + + self.storage_axis_names = storage_axis_names + self.storage_axis_sources = storage_axis_sources + self.storage_base_indices = storage_base_indices + self.non1_storage_axis_names = non1_storage_axis_names + + self.target_var_name = target_var_name + + def map_substitution(self, name, tag, arguments, expn_state): + process_me = name == self.subst_name + + if self.subst_tag is not None and self.subst_tag != tag: + process_me = False + + process_me = process_me and self.within(expn_state.stack) + + # {{{ find matching invocation descriptor + + rule = self.old_subst_rules[name] + arg_context = self.make_new_arg_context( + name, rule.arguments, arguments, expn_state.arg_context) + args = [arg_context[arg_name] for arg_name in rule.arguments] + + if not process_me: + return ExpandingIdentityMapper.map_substitution( + self, name, tag, arguments, expn_state) + + matching_invdesc = None + for invdesc 
in self.invocation_descriptors: + if invdesc.args == args and expn_state.stack: + # Could be more than one, that's fine. + matching_invdesc = invdesc + break + + assert matching_invdesc is not None + + invdesc = matching_invdesc + del matching_invdesc + + # }}} + + if not invdesc.is_in_footprint: + return ExpandingIdentityMapper.map_substitution( + self, name, tag, arguments, expn_state) + + assert len(arguments) == len(rule.arguments) + + stor_subscript = [] + for sax_name, sax_source, sax_base_idx in zip( + self.storage_axis_names, + self.storage_axis_sources, + self.storage_base_indices): + if sax_name not in self.non1_storage_axis_names: + continue + + if isinstance(sax_source, int): + # an argument + ax_index = arguments[sax_source] + else: + # an iname + ax_index = var(sax_source) + + ax_index = simplify_via_aff(ax_index - sax_base_idx) + stor_subscript.append(ax_index) + + new_outer_expr = var(self.target_var_name) + if stor_subscript: + new_outer_expr = new_outer_expr[tuple(stor_subscript)] + + return new_outer_expr + # can't possibly be nested, don't recurse + + +def precompute(kernel, subst_use, sweep_inames=[], within=None, storage_axes=None, new_storage_axis_names=None, storage_axis_to_tag={}, - default_tag="l.auto"): + default_tag="l.auto", dtype=None): """Precompute the expression described in the substitution rule determined by *subst_use* and store it in a temporary array. A precomputation needs two things to operate, a list of *sweep_inames* (order irrelevant) and an @@ -426,6 +576,7 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[], :arg sweep_inames: A :class:`list` of inames and/or rule argument names to be swept. :arg storage_axes: A :class:`list` of inames and/or rule argument names/indices to be used as storage axes. + :arg within: a stack match as understood by :func:`loopy.context_matching.parse_stack_match`. 
If `storage_axes` is not specified, it defaults to the arrangement `<direct sweep axes><arguments>` with the direct sweep axes being the @@ -486,6 +637,13 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[], raise ValueError("not all uses in subst_use agree " "on rule name and tag") + from loopy.context_matching import parse_stack_match + within = parse_stack_match(within) + + from loopy import infer_type + if dtype is None: + dtype = infer_type + # }}} # {{{ process invocations in footprint generators, start invocation_descriptors @@ -504,9 +662,9 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[], "be substitution rule invocation") invocation_descriptors.append( - InvocationDescriptor(expr=fpg, args=args, + InvocationDescriptor(args=args, expands_footprint=True, - from_subst_rule=None)) + )) # }}} @@ -520,63 +678,14 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[], # {{{ gather up invocations in kernel code, finish invocation_descriptors - current_subst_rule_stack = [] - - # We need to work on the fully expanded form of an expression. - # To that end, instantiate a substitutor. - from loopy.symbolic import ParametrizedSubstitutor - rules_except_mine = kernel.substitutions.copy() - del rules_except_mine[subst_name] - subst_expander = ParametrizedSubstitutor(rules_except_mine, - one_level=True) - - def gather_substs(expr, name, tag, args, rec): - if subst_name != name: - if name in subst_expander.rules: - # We can't deal with invocations that involve other substitution's - # arguments. Therefore, fully expand each encountered substitution - # rule and look at the invocations of subst_name occurring in its - # body. 
- - expanded_expr = subst_expander(expr) - current_subst_rule_stack.append(name) - result = rec(expanded_expr) - current_subst_rule_stack.pop() - return result - - else: - return None - - if subst_tag is not None and subst_tag != tag: - # use fall-back identity mapper - return None - - if len(args) != len(subst.arguments): - raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)" - % (subst_name, len(args), len(subst.arguments), )) - - arg_deps = get_dependencies(args) - if not arg_deps <= kernel.all_inames(): - raise RuntimeError("CSE arguments in '%s' do not consist " - "exclusively of inames" % expr) + invg = InvocationGatherer(kernel, subst_name, subst_tag, within) - if current_subst_rule_stack: - current_subst_rule = current_subst_rule_stack[-1] - else: - current_subst_rule = None + for insn in kernel.instructions: + invg(insn.expression, insn.id) + for invdesc in invg.invocation_descriptors: invocation_descriptors.append( - InvocationDescriptor(expr=expr, args=args, - expands_footprint=footprint_generators is None, - from_subst_rule=current_subst_rule)) - - return expr - - from loopy.symbolic import SubstitutionCallbackMapper - scm = SubstitutionCallbackMapper(names_filter=None, func=gather_substs) - - for insn in kernel.instructions: - scm(insn.expression) + invdesc.copy(expands_footprint=footprint_generators is None)) if not invocation_descriptors: raise RuntimeError("no invocations of '%s' found" % subst_name) @@ -608,7 +717,8 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[], from loopy.symbolic import ParametrizedSubstitutor submap = ParametrizedSubstitutor(kernel.substitutions) - value_inames = get_dependencies(submap(subst.expression)) & kernel.all_inames() + value_inames = get_dependencies( + submap(subst.expression, insn_id=None)) & kernel.all_inames() if value_inames - expanding_usage_arg_deps < extra_storage_axes: raise RuntimeError("unreferenced sweep inames specified: " + ", ".join(extra_storage_axes - value_inames - 
expanding_usage_arg_deps)) @@ -736,121 +846,13 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[], # {{{ substitute rule into expressions in kernel (if within footprint) - left_unused_subst_rule_invocations = [False] - - def do_substs(expr, name, tag, args, rec): - if tag != subst_tag: - left_unused_subst_rule_invocations[0] = True - return expr - - # {{{ check if current use is in-footprint - - if current_subst_rule is None: - # The current subsitution was *not* found inside another - # substitution rule. Try and dig up the corresponding invocation - # descriptor. - - found = False - for invdesc in invocation_descriptors: - if expr == invdesc.expr: - found = True - break - - if footprint_generators is None: - # We only have a right to find the expression if the - # invocation descriptors if they were generated by a scan - # of the code in the first place. If the user gave us - # the footprint generators, that isn't true. - - assert found, expr - - if not found or not invdesc.is_in_footprint: - left_unused_subst_rule_invocations[0] = True - return expr - - else: - # The current subsitution *was* found inside another substitution - # rule. We can't dig up the corresponding invocation descriptor, - # because it was the result of expanding that outer substitution - # rule. But we do know what the current outer substitution rule is, - # and we can check if all uses within that rule were uniformly - # in-footprint. If so, we'll go ahead, otherwise we'll bomb out. 
- - current_rule_invdescs_in_footprint = [ - invdesc.is_in_footprint - for invdesc in invocation_descriptors - if invdesc.from_subst_rule == current_subst_rule] - - from pytools import all - all_in = all(current_rule_invdescs_in_footprint) - all_out = all(not b for b in current_rule_invdescs_in_footprint) - - assert not (all_in and all_out) - - if not (all_in or all_out): - raise RuntimeError("substitution '%s' (being precomputed) is used " - "from within substitution '%s', but not all uses of " - "'%s' within '%s' " - "are uniformly within-footprint or outside of the footprint, " - "making a unique replacement of '%s' impossible. Please expand " - "'%s' and try again." - % (subst_name, current_subst_rule, - subst_name, current_subst_rule, - subst_name, current_subst_rule)) - - if all_out: - left_unused_subst_rule_invocations[0] = True - return expr - - assert all_in - - # }}} - - if len(args) != len(subst.arguments): - raise ValueError("invocation of '%s' with too few arguments" - % name) - - stor_subscript = [] - for sax_name, sax_source, sax_base_idx in zip( - storage_axis_names, storage_axis_sources, storage_base_indices): - if sax_name not in non1_storage_axis_names: - continue - - if isinstance(sax_source, int): - # an argument - ax_index = args[sax_source] - else: - # an iname - ax_index = var(sax_source) - - ax_index = simplify_via_aff(ax_index - sax_base_idx) - stor_subscript.append(ax_index) - - new_outer_expr = var(target_var_name) - if stor_subscript: - new_outer_expr = new_outer_expr[tuple(stor_subscript)] - - return new_outer_expr - # can't possibly be nested, don't recurse - - new_insns = [compute_insn] - - current_subst_rule = None - sub_map = SubstitutionCallbackMapper([subst_name], do_substs) - for insn in kernel.instructions: - new_insn = insn.copy(expression=sub_map(insn.expression)) - new_insns.append(new_insn) - - # also catch uses of our rule in other substitution rules - new_substs = {} - for s in kernel.substitutions.itervalues(): - 
current_subst_rule = s.name - new_substs[s.name] = s.copy( - expression=sub_map(s.expression)) + invr = InvocationReplacer(kernel, subst_name, subst_tag, within, + invocation_descriptors, + storage_axis_names, storage_axis_sources, + storage_base_indices, non1_storage_axis_names, + target_var_name) - # If the subst above caught all uses of the subst rule, get rid of it. - if not left_unused_subst_rule_invocations[0]: - del new_substs[subst_name] + kernel = invr.map_kernel(kernel) # }}} @@ -872,8 +874,7 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[], result = kernel.copy( domains=domch.get_domains_with(new_domain), - instructions=new_insns, - substitutions=new_substs, + instructions=[compute_insn] + kernel.instructions, temporary_variables=new_temporary_variables) from loopy import tag_inames diff --git a/loopy/kernel.py b/loopy/kernel.py index d9fd7c274..014a4f6e3 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -308,6 +308,8 @@ class SubstitutionRule(Record): """ def __init__(self, name, arguments, expression): + assert isinstance(arguments, tuple) + Record.__init__(self, name=name, arguments=arguments, expression=expression) @@ -343,15 +345,12 @@ class Instruction(Record): :ivar temp_var_type: if not None, a type that will be assigned to the new temporary variable created from the assignee - :ivar duplicate_inames_and_tags: a list of inames used in the instruction that will be duplicated onto - different inames. 
""" def __init__(self, id, assignee, expression, forced_iname_deps=frozenset(), insn_deps=set(), boostable=None, boostable_into=None, - temp_var_type=None, duplicate_inames_and_tags=[], - priority=0): + temp_var_type=None, priority=0): from loopy.symbolic import parse if isinstance(assignee, str): @@ -368,14 +367,13 @@ class Instruction(Record): insn_deps=insn_deps, boostable=boostable, boostable_into=boostable_into, temp_var_type=temp_var_type, - duplicate_inames_and_tags=duplicate_inames_and_tags, priority=priority) @memoize_method def reduction_inames(self): def map_reduction(expr, rec): rec(expr.expr) - for iname in expr.untagged_inames: + for iname in expr.inames: result.add(iname) from loopy.symbolic import ReductionCallbackMapper @@ -631,6 +629,29 @@ def _generate_unique_possibilities(prefix): yield "%s_%d" % (prefix, try_num) try_num += 1 +class _UniqueNameGenerator: + def __init__(self, existing_names): + self.existing_names = existing_names.copy() + + def is_name_conflicting(self, name): + return name in self.existing_names + + def add_name(self, name): + assert name not in self.existing_names + self.existing_names.add(name) + + def add_names(self, names): + assert not frozenset(names) & self.existing_names + self.existing_names.update(names) + + def __call__(self, based_on="var"): + for var_name in _generate_unique_possibilities(based_on): + if not self.is_name_conflicting(var_name): + break + + self.existing_names.add(var_name) + return var_name + _IDENTIFIER_RE = re.compile(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\b") def _gather_identifiers(s): @@ -783,13 +804,7 @@ class LoopKernel(Record): # {{{ parse instructions - INAME_ENTRY_RE = re.compile( - r"^\s*(?P<iname>\w+)\s*(?:\:\s*(?P<tag>[\w.]+))?\s*$") INSN_RE = re.compile( - "\s*(?:\[" - "(?P<iname_deps_and_tags>[\s\w,:.]*)" - "(?:\|(?P<duplicate_inames_and_tags>[\s\w,:.]*))?" - "\])?" "\s*(?:\<(?P<temp_var_type>.*?)\>)?" 
"\s*(?P<lhs>.+?)\s*(?<!\:)=\s*(?P<rhs>.+?)" "\s*?(?:\{(?P<options>[\s\w=,:]+)\}\s*)?$" @@ -798,32 +813,6 @@ class LoopKernel(Record): r"^\s*(?P<lhs>.+?)\s*:=\s*(?P<rhs>.+)\s*$" ) - def parse_iname_and_tag_list(s): - dup_entries = [ - dep.strip() for dep in s.split(",")] - result = [] - for entry in dup_entries: - if not entry: - continue - - entry_match = INAME_ENTRY_RE.match(entry) - if entry_match is None: - raise RuntimeError( - "could not parse iname:tag entry '%s'" - % entry) - - groups = entry_match.groupdict() - iname = groups["iname"] - assert iname - - tag = None - if groups["tag"] is not None: - tag = parse_tag(groups["tag"]) - - result.append((iname, tag)) - - return result - def parse_insn(insn): insn_match = INSN_RE.match(insn) subst_match = SUBST_RE.match(insn) @@ -870,20 +859,6 @@ class LoopKernel(Record): raise ValueError("unrecognized instruction option '%s'" % opt_key) - if groups["iname_deps_and_tags"] is not None: - inames_and_tags = parse_iname_and_tag_list( - groups["iname_deps_and_tags"]) - forced_iname_deps = frozenset(iname for iname, tag in inames_and_tags) - iname_to_tag_requests.update(dict(inames_and_tags)) - else: - forced_iname_deps = frozenset() - - if groups["duplicate_inames_and_tags"] is not None: - duplicate_inames_and_tags = parse_iname_and_tag_list( - groups["duplicate_inames_and_tags"]) - else: - duplicate_inames_and_tags = [] - if groups["temp_var_type"] is not None: if groups["temp_var_type"]: temp_var_type = np.dtype(groups["temp_var_type"]) @@ -903,10 +878,9 @@ class LoopKernel(Record): id=self.make_unique_instruction_id( parsed_instructions, based_on=insn_id), insn_deps=insn_deps, - forced_iname_deps=forced_iname_deps, + forced_iname_deps=frozenset(), assignee=lhs, expression=rhs, temp_var_type=temp_var_type, - duplicate_inames_and_tags=duplicate_inames_and_tags, priority=priority)) elif subst_match is not None: @@ -930,7 +904,7 @@ class LoopKernel(Record): substitutions[subst_name] = SubstitutionRule( name=subst_name, - 
arguments=arg_names, + arguments=tuple(arg_names), expression=rhs) def parse_if_necessary(insn): @@ -1111,11 +1085,16 @@ class LoopKernel(Record): | set(self.all_inames())) def make_unique_var_name(self, based_on="var", extra_used_vars=set()): - used_vars = self.all_variable_names() | extra_used_vars + from warnings import warn + warn("make_unique_var_name is deprecated, use get_var_name_generator " + "instead", DeprecationWarning, stacklevel=2) - for var_name in _generate_unique_possibilities(based_on): - if var_name not in used_vars: - return var_name + gen = self.get_var_name_generator() + gen.add_names(extra_used_vars) + return gen(based_on) + + def get_var_name_generator(self): + return _UniqueNameGenerator(self.all_variable_names()) def make_unique_instruction_id(self, insns=None, based_on="insn", extra_used_ids=set()): if insns is None: @@ -1623,22 +1602,6 @@ class LoopKernel(Record): # }}} - def map_expressions(self, func, exclude_instructions=False): - if exclude_instructions: - new_insns = self.instructions - else: - new_insns = [insn.copy( - expression=func(insn.expression), - assignee=func(insn.assignee), - ) - for insn in self.instructions] - - return self.copy( - instructions=new_insns, - substitutions=dict( - (subst.name, subst.copy(expression=func(subst.expression))) - for subst in self.substitutions.itervalues())) - # {{{ pretty-printing def __str__(self): @@ -1711,9 +1674,15 @@ def find_all_insn_inames(kernel): insn_id_to_inames = {} insn_assignee_inames = {} + all_read_deps = {} + all_write_deps = {} + + from loopy.subst import expand_subst + kernel = expand_subst(kernel) + for insn in kernel.instructions: - read_deps = get_dependencies(insn.expression) - write_deps = get_dependencies(insn.assignee) + all_read_deps[insn.id] = read_deps = get_dependencies(insn.expression) + all_write_deps[insn.id] = write_deps = get_dependencies(insn.assignee) deps = read_deps | write_deps iname_deps = ( @@ -1748,8 +1717,7 @@ def find_all_insn_inames(kernel): # of 
iname deps of all writers, and add those to insn's # dependencies. - for tv_name in (get_dependencies(insn.expression) - & temp_var_names): + for tv_name in (all_read_deps[insn.id] & temp_var_names): implicit_inames = None for writer_id in writer_map[tv_name]: @@ -1874,8 +1842,7 @@ class DomainChanger: # }}} - -# {{{ dot export +# {{{ graphviz / dot export def get_dot_dependency_graph(kernel, iname_cluster=False, iname_edge=True): lines = [] diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 4b2e29a10..ce5cf8c58 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -229,6 +229,7 @@ def realize_reduction(kernel, insn_id_filter=None): new_insns = [] new_temporary_variables = kernel.temporary_variables.copy() + orig_temp_var_names = set(kernel.temporary_variables) from loopy.codegen.expression import TypeInferenceMapper type_inf_mapper = TypeInferenceMapper(kernel) @@ -240,7 +241,7 @@ def realize_reduction(kernel, insn_id_filter=None): from pymbolic import var target_var_name = kernel.make_unique_var_name("acc_"+"_".join(expr.inames), - extra_used_vars=set(new_temporary_variables)) + extra_used_vars=set(new_temporary_variables) - orig_temp_var_names) target_var = var(target_var_name) arg_dtype = type_inf_mapper(expr.expr) @@ -437,11 +438,20 @@ def duplicate_private_temporaries_for_ilp(kernel): # }}} from pymbolic import var - return (kernel - .copy(temporary_variables=new_temp_vars) - .map_expressions(ExtraInameIndexInserter( - dict((var_name, tuple(var(iname) for iname in inames)) - for var_name, inames in var_to_new_ilp_inames.iteritems())))) + eiii = ExtraInameIndexInserter( + dict((var_name, tuple(var(iname) for iname in inames)) + for var_name, inames in var_to_new_ilp_inames.iteritems())) + + + new_insns = [ + insn.copy( + assignee=eiii(insn.assignee), + expression=eiii(insn.expression)) + for insn in kernel.instructions] + + return kernel.copy( + temporary_variables=new_temp_vars, + instructions=new_insns) # }}} diff --git a/loopy/subst.py 
b/loopy/subst.py index 1ed9788ca..eec7d74b9 100644 --- a/loopy/subst.py +++ b/loopy/subst.py @@ -91,7 +91,7 @@ def extract_subst(kernel, subst_name, template, parameters): if urecs: if len(urecs) > 1: - raise RuntimeError("ambiguous unification of '%s' with template '%s'" + raise RuntimeError("ambiguous unification of '%s' with template '%s'" % (expr, template)) urec, = urecs @@ -155,7 +155,7 @@ def extract_subst(kernel, subst_name, template, parameters): new_substs = { subst_name: SubstitutionRule( name=subst_name, - arguments=parameters, + arguments=tuple(parameters), expression=template, )} @@ -172,27 +172,13 @@ def extract_subst(kernel, subst_name, template, parameters): -def expand_subst(kernel, subst_name=None): - if subst_name is None: - rules = kernel.substitutions - else: - rule = kernel.substitutions[subst_name] - rules = {rule.name: rule} - +def expand_subst(kernel, ctx_match=None): from loopy.symbolic import ParametrizedSubstitutor - submap = ParametrizedSubstitutor(rules) - - if subst_name: - new_substs = kernel.substitutions.copy() - del new_substs[subst_name] - else: - new_substs = {} - - return (kernel - .copy(substitutions=new_substs) - .map_expressions(submap)) - - + from loopy.context_matching import parse_stack_match + submap = ParametrizedSubstitutor(kernel.substitutions, + kernel.get_var_name_generator(), + parse_stack_match(ctx_match)) + return submap.map_kernel(kernel) # vim: foldmethod=marker diff --git a/loopy/symbolic.py b/loopy/symbolic.py index cbb664a27..860c9ec18 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -27,11 +27,11 @@ THE SOFTWARE. 
-from pytools import memoize, memoize_method +from pytools import memoize, memoize_method, Record import pytools.lex from pymbolic.primitives import ( - Leaf, AlgebraicLeaf, Variable as VariableBase, + Leaf, AlgebraicLeaf, Variable, CommonSubexpression) from pymbolic.mapper import ( @@ -81,14 +81,14 @@ class TypedCSE(CommonSubexpression): return dict(dtype=self.dtype) -class TaggedVariable(VariableBase): +class TaggedVariable(Variable): """This is an identifier with a tag, such as 'matrix$one', where 'one' identifies this specific use of the identifier. This mechanism may then be used to address these uses--such as by prefetching only accesses tagged a certain way. """ def __init__(self, name, tag): - VariableBase.__init__(self, name) + Variable.__init__(self, name) self.tag = tag def __getinitargs__(self): @@ -129,13 +129,8 @@ class Reduction(AlgebraicLeaf): @property @memoize_method - def untagged_inames(self): - return tuple(iname.lstrip("@") for iname in self.inames) - - @property - @memoize_method - def untagged_inames_set(self): - return set(self.untagged_inames) + def inames_set(self): + return set(self.inames) mapper_method = intern("map_reduction") @@ -157,14 +152,14 @@ class LinearSubscript(AlgebraicLeaf): # {{{ mappers with support for loopy-specific primitives class IdentityMapperMixin(object): - def map_reduction(self, expr): - return Reduction(expr.operation, expr.inames, self.rec(expr.expr)) + def map_reduction(self, expr, *args): + return Reduction(expr.operation, expr.inames, self.rec(expr.expr, *args)) - def map_tagged_variable(self, expr): + def map_tagged_variable(self, expr, *args): # leaf, doesn't change return expr - def map_loopy_function_identifier(self, expr): + def map_loopy_function_identifier(self, expr, *args): return expr map_linear_subscript = IdentityMapperBase.map_subscript @@ -217,9 +212,8 @@ class StringifyMapper(StringifyMapperBase): class DependencyMapper(DependencyMapperBase): def map_reduction(self, expr): - from 
pymbolic.primitives import Variable return (self.rec(expr.expr) - - set(Variable(iname) for iname in expr.untagged_inames)) + - set(Variable(iname) for iname in expr.inames)) def map_tagged_variable(self, expr): return set([expr]) @@ -257,6 +251,187 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase): # }}} +# {{{ identity mapper that expands subst rules on the fly + +def parse_tagged_name(expr): + if isinstance(expr, TaggedVariable): + return expr.name, expr.tag + elif isinstance(expr, Variable): + return expr.name, None + else: + raise RuntimeError("subst rule name not understood: %s" % expr) + +class ExpansionState(Record): + """
 + :ivar stack: a tuple representing the current expansion stack, as a tuple + of (name, tag) pairs. At the top level, this should be initialized to a + tuple with the id of the calling instruction. + :ivar arg_context: a dict representing current argument values + """
 + +class ExpandingIdentityMapper(IdentityMapper): + """Note: the third argument dragged around by this mapper is the + current expansion state. + """
 + + def __init__(self, old_subst_rules, make_unique_var_name): + self.old_subst_rules = old_subst_rules + self.make_unique_var_name = make_unique_var_name + + # maps subst rule (args, bodies) to names + self.subst_rule_registry = dict( + ((rule.arguments, rule.expression), name) + for name, rule in old_subst_rules.iteritems()) + + # maps subst rule (args, bodies) to use counts + self.subst_rule_use_count = {} + + def register_subst_rule(self, name, args, body): + """Returns a name (as a string) for a newly created substitution + rule. 
+ """ + key = (args, body) + existing_name = self.subst_rule_registry.get(key) + + if existing_name is None: + new_name = self.make_unique_var_name(name) + self.subst_rule_registry[key] = new_name + else: + new_name = existing_name + + self.subst_rule_use_count[key] = self.subst_rule_use_count.get(key, 0) + 1 + return new_name + + def map_variable(self, expr, expn_state): + name, tag = parse_tagged_name(expr) + if name not in self.old_subst_rules: + return IdentityMapper.map_variable(self, expr, expn_state) + else: + return self.map_substitution(name, tag, (), expn_state) + + def map_call(self, expr, expn_state): + if not isinstance(expr.function, Variable): + return IdentityMapper.map_call(self, expr, expn_state) + + name, tag = parse_tagged_name(expr.function) + + if name not in self.old_subst_rules: + return IdentityMapper.map_call(self, expr, expn_state) + else: + return self.map_substitution(name, tag, expr.parameters, expn_state) + + @staticmethod + def make_new_arg_context(rule_name, arg_names, arguments, arg_context): + if len(arg_names) != len(arguments): + raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)" + % (rule_name, len(arguments), len(arg_names), )) + + from pymbolic.mapper.substitutor import make_subst_func + arg_subst_map = SubstitutionMapper(make_subst_func(arg_context)) + return dict( + (formal_arg_name, arg_subst_map(arg_value)) + for formal_arg_name, arg_value in zip(arg_names, arguments)) + + def map_substitution(self, name, tag, arguments, expn_state): + rule = self.old_subst_rules[name] + + rec_arguments = self.rec(arguments, expn_state) + new_expn_state = expn_state.copy( + stack=expn_state.stack + ((name, tag),), + arg_context=self.make_new_arg_context( + name, rule.arguments, rec_arguments, expn_state.arg_context)) + + result = self.rec(rule.expression, new_expn_state) + + new_name = self.register_subst_rule(name, rule.arguments, result) + + if tag is None: + sym = Variable(new_name) + else: + sym = 
TaggedVariable(new_name, tag) + + if arguments: + return sym(*rec_arguments) + else: + return sym + + def __call__(self, expr, insn_id): + if insn_id is not None: + stack = (insn_id,) + else: + stack = () + + return IdentityMapper.__call__(self, expr, ExpansionState( + stack=stack, arg_context={})) + + def get_new_substitutions(self): + from loopy.kernel import SubstitutionRule + + result = {} + for key, name in self.subst_rule_registry.iteritems(): + args, body = key + + if self.subst_rule_use_count.get(key, 0): + result[name] = SubstitutionRule( + name=name, + arguments=args, + expression=body) + + return result + + def map_kernel(self, kernel): + new_insns = [ + insn.copy( + assignee=self(insn.assignee, insn.id), + expression=self(insn.expression, insn.id)) + for insn in kernel.instructions] + + return kernel.copy( + substitutions=self.get_new_substitutions(), + instructions=new_insns) + + +# }}} + +# {{{ parametrized substitutor + +class ParametrizedSubstitutor(ExpandingIdentityMapper): + def __init__(self, rules, make_unique_var=None, ctx_match=None): + ExpandingIdentityMapper.__init__(self, rules, make_unique_var) + + if ctx_match is None: + from loopy.context_matching import AllStackMatch + ctx_match = AllStackMatch() + + self.ctx_match = ctx_match + + def map_substitution(self, name, tag, arguments, expn_state): + new_stack = expn_state.stack + ((name, tag),) + if self.ctx_match(new_stack): + # expand + rule = self.old_subst_rules[name] + + new_expn_state = expn_state.copy( + stack=new_stack, + arg_context=self.make_new_arg_context( + name, rule.arguments, arguments, expn_state.arg_context)) + + result = self.rec(rule.expression, new_expn_state) + + # substitute in argument values + from pymbolic.mapper.substitutor import make_subst_func + subst_map = SubstitutionMapper(make_subst_func( + new_expn_state.arg_context)) + + return subst_map(result) + + else: + # do not expand + return ExpandingIdentityMapper.map_substitution( + self, name, tag, arguments, 
expn_state) + +# }}} + # {{{ functions to primitives, parsing class VarToTaggedVarMapper(IdentityMapper): @@ -343,7 +518,7 @@ class FunctionToPrimitiveMapper(IdentityMapper): return Reduction(operation, tuple(processed_inames), red_expr) -# {{{ parser extension +# {{{ customization to pymbolic parser _open_dbl_bracket = intern("open_dbl_bracket") _close_dbl_bracket = intern("close_dbl_bracket") @@ -392,26 +567,6 @@ def parse(expr_str): # }}} -# {{{ reduction loop splitter - -class ReductionLoopSplitter(IdentityMapper): - def __init__(self, old_iname, outer_iname, inner_iname): - self.old_iname = old_iname - self.outer_iname = outer_iname - self.inner_iname = inner_iname - - def map_reduction(self, expr): - if self.old_iname in expr.inames: - new_inames = list(expr.inames) - new_inames.remove(self.old_iname) - new_inames.extend([self.outer_iname, self.inner_iname]) - return Reduction(expr.operation, tuple(new_inames), - expr.expr) - else: - return IdentityMapper.map_reduction(self, expr) - -# }}} - # {{{ coefficient collector class CoefficientCollector(RecursiveMapper): @@ -641,134 +796,13 @@ class IndexVariableFinder(CombineMapper): def map_reduction(self, expr): result = self.rec(expr.expr) - if not (expr.untagged_inames_set & result): + if not (expr.inames_set & result): raise RuntimeError("reduction '%s' does not depend on " "reduction inames (%s)" % (expr, ",".join(expr.inames))) if self.include_reduction_inames: return result else: - return result - expr.untagged_inames_set - -# }}} - -# {{{ substitution callback mapper - -class SubstitutionCallbackMapper(IdentityMapper): - @staticmethod - def parse_filter(filt): - if not isinstance(filt, tuple): - components = filt.split("$") - if len(components) == 1: - return (components[0], None) - elif len(components) == 2: - return tuple(components) - else: - raise RuntimeError("too many components in '%s'" % filt) - else: - if len(filt) != 2: - raise RuntimeError("substitution name filters " - "may have at most two 
components") - - return filt - - def __init__(self, names_filter, func): - if names_filter is not None: - new_names_filter = [] - for filt in names_filter: - new_names_filter.append(self.parse_filter(filt)) - - self.names_filter = new_names_filter - else: - self.names_filter = names_filter - - self.func = func - - def parse_name(self, expr): - from pymbolic.primitives import Variable - if isinstance(expr, TaggedVariable): - e_name, e_tag = expr.name, expr.tag - elif isinstance(expr, Variable): - e_name, e_tag = expr.name, None - else: - return None - - if self.names_filter is not None: - for filt_name, filt_tag in self.names_filter: - if e_name == filt_name: - if filt_tag is None or filt_tag == e_tag: - return e_name, e_tag - else: - return e_name, e_tag - - return None - - def map_variable(self, expr): - parsed_name = self.parse_name(expr) - if parsed_name is None: - return getattr(IdentityMapper, expr.mapper_method)(self, expr) - - name, tag = parsed_name - - result = self.func(expr, name, tag, (), self.rec) - if result is None: - return getattr(IdentityMapper, expr.mapper_method)(self, expr) - else: - return result - - map_tagged_variable = map_variable - - def map_call(self, expr): - from pymbolic.primitives import Lookup - if isinstance(expr.function, Lookup): - raise RuntimeError("dotted name '%s' not allowed as " - "function identifier" % expr.function) - - parsed_name = self.parse_name(expr.function) - - if parsed_name is None: - return IdentityMapper.map_call(self, expr) - - name, tag = parsed_name - - result = self.func(expr, name, tag, expr.parameters, self.rec) - if result is None: - return IdentityMapper.map_call(self, expr) - else: - return result - -# }}} - -# {{{ parametrized substitutor - -class ParametrizedSubstitutor(object): - def __init__(self, rules, one_level=False): - self.rules = rules - self.one_level = one_level - - def __call__(self, expr): - level = [0] - - def expand_if_known(expr, name, instance, args, rec): - if self.one_level and 
level[0] > 0: - return None - - rule = self.rules[name] - if len(rule.arguments) != len(args): - raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)" - % (name, len(args), len(rule.arguments), )) - - from pymbolic.mapper.substitutor import make_subst_func - subst_map = SubstitutionMapper(make_subst_func( - dict(zip(rule.arguments, args)))) - - level[0] += 1 - result = rec(subst_map(rule.expression)) - level[0] -= 1 - - return result - - scm = SubstitutionCallbackMapper(self.rules.keys(), expand_if_known) - return scm(expr) + return result - expr.inames_set # }}} diff --git a/test/test_linalg.py b/test/test_linalg.py index 75604febb..0bd22021a 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -29,7 +29,6 @@ import numpy as np import numpy.linalg as la import pyopencl as cl import pyopencl.array as cl_array -import pyopencl.clrandom as cl_random import loopy as lp from pyopencl.tools import pytest_generate_tests_for_pyopencl \ diff --git a/test/test_loopy.py b/test/test_loopy.py index 84f8e0c21..8798ac6ec 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -39,6 +39,40 @@ __all__ = ["pytest_generate_tests", +def test_complicated_subst(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel(ctx.devices[0], + "{[i]: 0<=i<n}", + """ + f(x) := x*a[x] + g(x) := 12 + f(x) + h(x) := 1 + g(x) + 20*g$two(x) + + a[i] = h$one(i) * h$two(i) + """, + [ + lp.GlobalArg("a", np.float32, shape=("n",)), + lp.ValueArg("n", np.int32), + ]) + + from loopy.subst import expand_subst + knl = expand_subst(knl, "g$two < h$two") + + print knl + + sr_keys = knl.substitutions.keys() + for letter, how_many in [ + ("f", 1), + ("g", 1), + ("h", 2) + ]: + substs_with_letter = sum(1 for k in sr_keys if k.startswith(letter)) + assert substs_with_letter == how_many + + + + def test_type_inference_no_artificial_doubles(ctx_factory): ctx = ctx_factory() @@ -135,11 +169,13 @@ def test_owed_barriers(ctx_factory): knl = lp.make_kernel(ctx.devices[0], "{[i]: 
0<=i<100}", [ - "[i:l.0] <float32> z[i] = a[i]" + "<float32> z[i] = a[i]" ], [lp.GlobalArg("a", np.float32, shape=(100,))] ) + knl = lp.tag_inames(knl, dict(i="l.0")) + kernel_gen = lp.generate_loop_schedules(knl) kernel_gen = lp.check_kernels(kernel_gen) @@ -156,11 +192,13 @@ def test_wg_too_small(ctx_factory): knl = lp.make_kernel(ctx.devices[0], "{[i]: 0<=i<100}", [ - "[i:l.0] <float32> z[i] = a[i] {id=copy}" + "<float32> z[i] = a[i] {id=copy}" ], [lp.GlobalArg("a", np.float32, shape=(100,))], local_sizes={0: 16}) + knl = lp.tag_inames(knl, dict(i="l.0")) + kernel_gen = lp.generate_loop_schedules(knl) kernel_gen = lp.check_kernels(kernel_gen) @@ -242,7 +280,7 @@ def test_multi_cse(ctx_factory): knl = lp.make_kernel(ctx.devices[0], "{[i]: 0<=i<100}", [ - "[i] <float32> z[i] = a[i] + a[i]**2" + "<float32> z[i] = a[i] + a[i]**2" ], [lp.GlobalArg("a", np.float32, shape=(100,))], local_sizes={0: 16}) @@ -816,7 +854,7 @@ def test_ilp_write_race_detection_global(ctx_factory): "[n] -> {[i,j]: 0<=i,j<n }", ], [ - "[j:ilp] a[i] = 5+i+j", + "a[i] = 5+i+j", ], [ lp.GlobalArg("a", np.float32), @@ -824,6 +862,8 @@ def test_ilp_write_race_detection_global(ctx_factory): ], assumptions="n>=1") + knl = lp.tag_inames(knl, dict(j="ilp")) + from loopy.check import WriteRaceConditionError import pytest with pytest.raises(WriteRaceConditionError): @@ -838,10 +878,12 @@ def test_ilp_write_race_avoidance_local(ctx_factory): knl = lp.make_kernel(ctx.devices[0], "{[i,j]: 0<=i<16 and 0<=j<17 }", [ - "[i:l.0, j:ilp] <> a[i] = 5+i+j", + "<> a[i] = 5+i+j", ], []) + knl = lp.tag_inames(knl, dict(i="l.0", j="ilp")) + for k in lp.generate_loop_schedules(knl): assert k.temporary_variables["a"].shape == (16,17) @@ -854,10 +896,12 @@ def test_ilp_write_race_avoidance_private(ctx_factory): knl = lp.make_kernel(ctx.devices[0], "{[j]: 0<=j<16 }", [ - "[j:ilp] <> a = 5+j", + "<> a = 5+j", ], []) + knl = lp.tag_inames(knl, dict(j="ilp")) + for k in lp.generate_loop_schedules(knl): assert 
k.temporary_variables["a"].shape == (16,) -- GitLab