diff --git a/loopy/__init__.py b/loopy/__init__.py index 7fd63dff33fc9693956f6255908bb2684f0ab765..843356bf363f8c918a840daa879676cbe8a8f0ee 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -62,11 +62,12 @@ def make_kernel(*args, **kwargs): newly_created_vars = set() - from loopy.symbolic import CSESubstitutor - cse_sub = CSESubstitutor(knl.cses) + from loopy.symbolic import ParametrizedSubstitutor + cse_sub = ParametrizedSubstitutor(knl.cses, wrap_cse=True) + subst_sub = ParametrizedSubstitutor(knl.substitutions, wrap_cse=False) for insn in knl.instructions: - insn = insn.copy(expression=cse_sub(insn.expression)) + insn = insn.copy(expression=subst_sub(cse_sub(insn.expression))) # {{{ sanity checking diff --git a/loopy/kernel.py b/loopy/kernel.py index 53ef81e681f91091d89724e66b4e5ed9141d30ac..6af2b8726e453569784b0419b9e4c2dd2d8de96b 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -449,6 +449,7 @@ class LoopKernel(Record): :ivar iname_to_tag_requests: :ivar cses: a mapping from CSE names to tuples (arg_names, expr). + :ivar substitutions: a mapping from substitution names to tuples (arg_names, expr). """ def __init__(self, device, domain, instructions, args=None, schedule=None, @@ -457,7 +458,7 @@ class LoopKernel(Record): iname_slab_increments={}, temporary_variables={}, local_sizes={}, - iname_to_tag={}, iname_to_tag_requests=None, cses={}): + iname_to_tag={}, iname_to_tag_requests=None, cses={}, substitutions={}): """ :arg domain: a :class:`islpy.BasicSet`, or a string parseable to a basic set by the isl. 
Example: "{[i,j]: 0<=i < 10 and 0<= j < 9}" @@ -539,7 +540,7 @@ class LoopKernel(Record): from loopy.symbolic import FunctionToPrimitiveMapper rhs = FunctionToPrimitiveMapper()(parse(groups["rhs"])) - if label.lower() != "cse": + if label.lower() not in ["cse", "subst"]: if groups["insn_deps"] is not None: insn_deps = set(dep.strip() for dep in groups["insn_deps"].split(",")) else: @@ -574,11 +575,11 @@ class LoopKernel(Record): duplicate_inames_and_tags=duplicate_inames_and_tags)) else: if groups["iname_deps_and_tags"] is not None: - raise RuntimeError("CSEs cannot declare iname dependencies") + raise RuntimeError("CSEs/substitutions cannot declare iname dependencies") if groups["insn_deps"] is not None: - raise RuntimeError("CSEs cannot declare instruction dependencies") + raise RuntimeError("CSEs/substitutions cannot declare instruction dependencies") if groups["temp_var_type"] is not None: - raise RuntimeError("CSEs cannot declare temporary storage") + raise RuntimeError("CSEs/substitutions cannot declare temporary storage") from pymbolic.primitives import Variable, Call @@ -598,12 +599,18 @@ class LoopKernel(Record): else: raise RuntimeError("CSEs cannot declare temporary storage") - cses[cse_name] = (arg_names, rhs) + if label.lower() == "cse": + cses[cse_name] = (arg_names, rhs) + else: + substitutions[cse_name] = (arg_names, rhs) # }}} insns = [] + cses = cses.copy() + substitutions = substitutions.copy() + for insn in instructions: # must construct list one-by-one to facilitate unique id generation parse_if_necessary(insn) @@ -635,7 +642,7 @@ class LoopKernel(Record): local_sizes=local_sizes, iname_to_tag=iname_to_tag, iname_to_tag_requests=iname_to_tag_requests, - cses=cses) + cses=cses, substitutions=substitutions) def make_unique_instruction_id(self, insns=None, based_on="insn", extra_used_ids=set()): if insns is None: diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 
21ff0bd16464c3e8c2bdf45a1996cfafc10e60b6..49e9dfdbf4616a73cc72f3f29142f7344f3d664f 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -602,14 +602,16 @@ class VariableFetchCSEMapper(IdentityMapper): # }}} -# {{{ CSE substitutor +# {{{ parametrized substitutor -class CSESubstitutor(IdentityMapper): - def __init__(self, cses): +class ParametrizedSubstitutor(IdentityMapper): + def __init__(self, cses, wrap_cse): """ :arg cses: a mapping from CSE names to tuples (arg_names, expr). + :arg wrap_cse: flag: wrap substituted expressions in CSEs """ self.cses = cses + self.wrap_cse = wrap_cse def map_variable(self, expr): if expr.name not in self.cses: @@ -620,8 +622,11 @@ class CSESubstitutor(IdentityMapper): raise RuntimeError("CSE '%s' must be invoked with %d arguments" % (expr.name, len(arg_names))) - from pymbolic.primitives import CommonSubexpression - return CommonSubexpression(cse_expr, expr.name) + if self.wrap_cse: + from pymbolic.primitives import CommonSubexpression + return CommonSubexpression(cse_expr, expr.name) + else: + return cse_expr def map_call(self, expr): from pymbolic.primitives import Variable, CommonSubexpression @@ -639,7 +644,12 @@ class CSESubstitutor(IdentityMapper): subst_map = SubstitutionMapper(make_subst_func( dict(zip(arg_names, expr.parameters)))) - return CommonSubexpression(subst_map(cse_expr), cse_name) + cse_expr = subst_map(cse_expr) + + if self.wrap_cse: + return CommonSubexpression(cse_expr, cse_name) + else: + return cse_expr # }}}