diff --git a/loopy/__init__.py b/loopy/__init__.py index 0043a21b1ba9082b6b076800580f27eb651e1e86..913479d789e1217a00dd2994c7df1821c368d506 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -22,7 +22,7 @@ class LoopyAdvisory(UserWarning): from loopy.kernel import ScalarArg, ArrayArg, ConstantArrayArg, ImageArg -from loopy.kernel import AutoFitLocalIndexTag, get_dot_dependency_graph, LoopKernel +from loopy.kernel import AutoFitLocalIndexTag, get_dot_dependency_graph from loopy.subst import extract_subst, expand_subst from loopy.cse import precompute from loopy.preprocess import preprocess_kernel, realize_reduction @@ -453,15 +453,31 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None, "may not contain a subscript") assert isinstance(parsed_var_name.aggregate, Variable) - var_name = parsed_var_name.aggregate.name footprint_subscripts = [parsed_var_name.index] + parsed_var_name = parsed_var_name.aggregate else: raise ValueError("var_name must either be a variable name or a subscript") # }}} + # {{{ fish out tag + + from loopy.symbolic import TaggedVariable + if isinstance(parsed_var_name, TaggedVariable): + var_name = parsed_var_name.name + tag = parsed_var_name.tag + else: + var_name = parsed_var_name.name + tag = None + + # }}} + + c_name = var_name + if tag is not None: + c_name = c_name + "_" + tag + if rule_name is None: - rule_name = kernel.make_unique_var_name("%s_fetch" % var_name) + rule_name = kernel.make_unique_var_name("%s_fetch" % c_name) newly_created_vars = set([rule_name]) @@ -469,7 +485,7 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None, parameters = [] for i in range(arg.dimensions): - based_on = "%s_dim_%d" % (var_name, i) + based_on = "%s_dim_%d" % (c_name, i) if dim_arg_names is not None and i < len(dim_arg_names): based_on = dim_arg_names[i] @@ -479,7 +495,7 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None, parameters.append(par_name) from pymbolic import var - uni_template = var(var_name) + uni_template = parsed_var_name if len(parameters) > 1: uni_template = uni_template[tuple(var(par_name) for par_name in parameters)] elif len(parameters) == 1: @@ -535,5 +551,4 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None, - # vim: foldmethod=marker diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index b871721d2730abc787ff6c631cf97da1c52563b0..adbd3e55afcdeaca669130e4f51cfac95afc8658 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -116,6 +116,9 @@ class LoopyCCodeMapper(CCodeMapper): else: return CCodeMapper.map_variable(self, expr, prec) + def map_tagged_variable(self, expr, enclosing_prec): + return expr.name + def map_subscript(self, expr, enclosing_prec): from pymbolic.primitives import Variable if not isinstance(expr.aggregate, Variable): diff --git a/loopy/cse.py b/loopy/cse.py index 6c8455b9a2a9c9e414a2b40e834c111154b923ac..398a8089722235ae7647b8db39819c7b9d4bee7e 100644 --- a/loopy/cse.py +++ b/loopy/cse.py @@ -359,7 +359,7 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[], from loopy.symbolic import SubstitutionCallbackMapper c_subst_name = subst_name.replace(".", "_") - subst_name, subst_instance = SubstitutionCallbackMapper.parse_filter(subst_name) + subst_name, subst_tag = SubstitutionCallbackMapper.parse_filter(subst_name) from loopy.kernel import parse_tag default_tag = parse_tag(default_tag) @@ -407,7 +407,7 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[], subst_expander = ParametrizedSubstitutor(rules_except_mine, one_level=True) - def gather_substs(expr, name, instance, args, rec): + def gather_substs(expr, name, tag, args, rec): if subst_name != name: if name in subst_expander.rules: # We can't deal with invocations that involve other substitution's @@ -424,7 +424,7 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[], else: return None - if subst_instance != instance: + if subst_tag != tag: # use fall-back identity mapper return None @@ -571,25 +571,11 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[], # }}} - # {{{ set up temp variable + # {{{ set up compute insn target_var_name = kernel.make_unique_var_name(based_on=c_subst_name, extra_used_vars=newly_created_var_names) - from loopy.kernel import TemporaryVariable - - new_temporary_variables = kernel.temporary_variables.copy() - new_temporary_variables[target_var_name] = TemporaryVariable( - name=target_var_name, - dtype=np.dtype(dtype), - base_indices=(0,)*len(non1_storage_shape), - shape=non1_storage_shape, - is_local=None) - - # }}} - - # {{{ set up compute insn - assignee = var(target_var_name) if non1_storage_axis_names: @@ -620,8 +606,8 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[], left_unused_subst_rule_invocations = [False] - def do_substs(expr, name, instance, args, rec): - if instance != subst_instance: + def do_substs(expr, name, tag, args, rec): + if tag != subst_tag: left_unused_subst_rule_invocations[0] = True return expr @@ -730,10 +716,30 @@ def precompute(kernel, subst_name, dtype, sweep_inames=[], # }}} + # {{{ fill out new_iname_to_tag + new_iname_to_tag = kernel.iname_to_tag.copy() for arg_name in non1_storage_axis_names: new_iname_to_tag[arg_name] = storage_axis_name_to_tag[arg_name] + # }}} + + # {{{ set up temp variable + + from loopy.kernel import TemporaryVariable + + new_temporary_variables = kernel.temporary_variables.copy() + temp_var = TemporaryVariable( + name=target_var_name, + dtype=np.dtype(dtype), + base_indices=(0,)*len(non1_storage_shape), + shape=non1_storage_shape, + is_local=None) + + new_temporary_variables[target_var_name] = temp_var + + # }}} + return kernel.copy( domain=new_domain, instructions=new_insns, diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 1ed6eb4fc3d21df25363acb62236fc05cf2200fd..e20ddbec296fb83c87b3e9f95a02508a895fced2 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -4,7 +4,9 @@ from __future__ import division from pytools import memoize, memoize_method -from pymbolic.primitives import AlgebraicLeaf +from pymbolic.primitives import ( + AlgebraicLeaf, Variable as VariableBase) + from pymbolic.mapper import ( CombineMapper as CombineMapperBase, IdentityMapper as IdentityMapperBase, @@ -29,6 +31,19 @@ from islpy import dim_type # {{{ loopy-specific primitives +class TaggedVariable(VariableBase): + def __init__(self, name, tag): + VariableBase.__init__(self, name) + self.tag = tag + + def __getinitargs__(self): + return self.name, self.tag + + def stringifier(self): + return StringifyMapper + + mapper_method = intern("map_tagged_variable") + class Reduction(AlgebraicLeaf): def __init__(self, operation, inames, expr): assert isinstance(inames, tuple) @@ -77,6 +92,10 @@ class IdentityMapperMixin(object): def map_reduction(self, expr): return Reduction(expr.operation, expr.inames, self.rec(expr.expr)) + def map_tagged_variable(self, expr): + # leaf, doesn't change + return expr + class IdentityMapper(IdentityMapperBase, IdentityMapperMixin): pass @@ -99,12 +118,18 @@ class StringifyMapper(StringifyMapperBase): return "reduce(%s, [%s], %s)" % ( expr.operation, ", ".join(expr.inames), expr.expr) + def map_tagged_variable(self, expr, prec): + return "%s$%s" % (expr.name, expr.tag) + class DependencyMapper(DependencyMapperBase): def map_reduction(self, expr): from pymbolic.primitives import Variable return (self.rec(expr.expr) - set(Variable(iname) for iname in expr.untagged_inames)) + def map_tagged_variable(self, expr): + return set([expr]) + class UnidirectionalUnifier(UnidirectionalUnifierBase): def map_reduction(self, expr, other, unis): if not isinstance(other, type(expr)): @@ -115,10 +140,35 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase): return self.rec(expr.expr, other.expr, unis) + def map_tagged_variable(self, expr, other, urecs): + new_uni_record = self.unification_record_from_equation( + expr, other) + if new_uni_record is None: + # Check if the variables match literally--that's ok, too. + if (isinstance(other, TaggedVariable) + and expr.name == other.name + and expr.tag == other.tag + and expr.name not in self.lhs_mapping_candidates): + return urecs + else: + return [] + else: + from pymbolic.mapper.unifier import unify_many + return unify_many(urecs, new_uni_record) + # }}} # {{{ functions to primitives, parsing +class VarToTaggedVarMapper(IdentityMapper): + def map_variable(self, expr): + dollar_idx = expr.name.find("$") + if dollar_idx == -1: + return expr + else: + return TaggedVariable(expr.name[:dollar_idx], + expr.name[dollar_idx+1:]) + class FunctionToPrimitiveMapper(IdentityMapper): """Looks for invocations of a function called 'cse' or 'reduce' and turns those into the actual pymbolic primitives used for that. @@ -196,7 +246,8 @@ class FunctionToPrimitiveMapper(IdentityMapper): def parse(expr_str): from pymbolic import parse - return FunctionToPrimitiveMapper()(parse(expr_str)) + return VarToTaggedVarMapper()( + FunctionToPrimitiveMapper()(parse(expr_str))) # }}} @@ -273,6 +324,8 @@ class CoefficientCollector(RecursiveMapper): def map_variable(self, expr): return {expr.name: 1} + map_tagged_variable = map_variable + def map_subscript(self, expr): raise RuntimeError("cannot gather coefficients--indirect addressing in use") @@ -454,13 +507,13 @@ class SubstitutionCallbackMapper(IdentityMapper): @staticmethod def parse_filter(filt): if not isinstance(filt, tuple): - dotted_components = filt.split(".") - if len(dotted_components) == 1: - return (dotted_components[0], None) - elif len(dotted_components) == 2: - return tuple(dotted_components) + components = filt.split("$") + if len(components) == 1: + return (components[0], None) + elif len(components) == 2: + return tuple(components) else: - raise RuntimeError("too many dotted components in '%s'" % filt) + raise RuntimeError("too many components in '%s'" % filt) else: if len(filt) != 2: raise RuntimeError("substitution name filters " @@ -481,60 +534,47 @@ class SubstitutionCallbackMapper(IdentityMapper): self.func = func def parse_name(self, expr): - from pymbolic.primitives import Variable, Lookup + from pymbolic.primitives import Variable if isinstance(expr, Variable): - e_name, e_instance = expr.name, None - elif isinstance(expr, Lookup): - if not isinstance(expr.aggregate, Variable): - return None - e_name, e_instance = expr.aggregate.name, expr.name + e_name, e_tag = expr.name, None + elif isinstance(expr, TaggedVariable): + e_name, e_tag = expr.name, expr.tag else: return None if self.names_filter is not None: - for filt_name, filt_instance in self.names_filter: + for filt_name, filt_tag in self.names_filter: if e_name == filt_name: - if filt_instance is None or filt_instance == e_instance: - return e_name, e_instance + if filt_tag is None or filt_tag == e_tag: + return e_name, e_tag else: - return e_name, e_instance + return e_name, e_tag return None def map_variable(self, expr): parsed_name = self.parse_name(expr) if parsed_name is None: - return IdentityMapper.map_variable(self, expr) + return getattr(IdentityMapper, expr.mapper_method)(self, expr) - name, instance = parsed_name + name, tag = parsed_name - result = self.func(expr, name, instance, (), self.rec) + result = self.func(expr, name, tag, (), self.rec) if result is None: - return IdentityMapper.map_variable(self, expr) + return getattr(IdentityMapper, expr.mapper_method)(self, expr) else: return result - def map_lookup(self, expr): - parsed_name = self.parse_name(expr) - if parsed_name is None: - return IdentityMapper.map_lookup(self, expr) - - name, instance = parsed_name - - result = self.func(expr, name, instance, (), self.rec) - if result is None: - return IdentityMapper.map_lookup(self, expr) - else: - return result + map_tagged_variable = map_variable def map_call(self, expr): parsed_name = self.parse_name(expr.function) if parsed_name is None: return IdentityMapper.map_call(self, expr) - name, instance = parsed_name + name, tag = parsed_name - result = self.func(expr, name, instance, expr.parameters, self.rec) + result = self.func(expr, name, tag, expr.parameters, self.rec) if result is None: return IdentityMapper.map_call(self, expr) else: @@ -601,6 +641,13 @@ class PrimeAdder(IdentityMapper): else: return expr + def map_tagged_variable(self, expr): + if expr.name in self.which_vars: + return TaggedVariable(expr.name+"'", expr.tag) + else: + return expr + + # }}} @memoize