diff --git a/loopy/__init__.py b/loopy/__init__.py index dc02512ba5b5d75c8d73ca5f7b8e8791975a8d27..e1819f1f994969f29b545a0272016867bcc48f9b 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -1207,8 +1207,12 @@ def _fix_parameter(kernel, name, value): from pymbolic.mapper.substitutor import make_subst_func subst_func = make_subst_func({name: value}) - from loopy.symbolic import SubstitutionMapper + from loopy.symbolic import SubstitutionMapper, PartialEvaluationMapper subst_map = SubstitutionMapper(subst_func) + ev_map = PartialEvaluationMapper() + + def map_expr(expr): + return ev_map(subst_map(expr)) from loopy.kernel.array import ArrayBase new_args = [] @@ -1220,11 +1224,11 @@ def _fix_parameter(kernel, name, value): if not isinstance(arg, ArrayBase): new_args.append(arg) else: - new_args.append(arg.map_exprs(subst_map)) + new_args.append(arg.map_exprs(map_expr)) new_temp_vars = {} for tv in kernel.temporary_variables.itervalues(): - new_temp_vars[tv.name] = tv.map_exprs(subst_map) + new_temp_vars[tv.name] = tv.map_exprs(map_expr) from loopy.context_matching import parse_stack_match within = parse_stack_match(None) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index fa0b91be8ac4d3ec0c721c4164fc983cf4643b56..acf361dfdf0c5bc8a7d5d3fe78b28442f4a5c8bf 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -131,8 +131,9 @@ def expand_defines_in_expr(expr, defines): else: return None - from loopy.symbolic import SubstitutionMapper - return SubstitutionMapper(subst_func)(expr) + from loopy.symbolic import SubstitutionMapper, PartialEvaluationMapper + return PartialEvaluationMapper()( + SubstitutionMapper(subst_func)(expr)) # }}} @@ -746,15 +747,24 @@ def expand_defines_in_shapes(kernel, defines): from loopy.kernel.array import ArrayBase from loopy.kernel.creation import expand_defines_in_expr + def expr_map(expr): + return expand_defines_in_expr(expr, defines) + processed_args = [] for arg in kernel.args: if isinstance(arg, ArrayBase): - arg = arg.map_exprs( - lambda expr: expand_defines_in_expr(expr, defines)) + arg = arg.map_exprs(expr_map) processed_args.append(arg) - return kernel.copy(args=processed_args) + processed_temp_vars = {} + for tv in kernel.temporary_variables.itervalues(): + processed_temp_vars[tv.name] = tv.map_exprs(expr_map) + + return kernel.copy( + args=processed_args, + temporary_variables=processed_temp_vars, + ) # }}} @@ -867,7 +877,7 @@ def resolve_wildcard_deps(knl): if match_count == 0: # Uh, best we can do - new_deps.append(dep) + new_deps.add(dep) insn = insn.copy(insn_deps=frozenset(new_deps)) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 81837f98d170a35125a69c721327610e57371a3c..22291c7fb32990d733a2abc0d6f9065b7d88b36e 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -38,6 +38,8 @@ from pymbolic.mapper import ( WalkMapper as WalkMapperBase, CallbackMapper as CallbackMapperBase, ) +from pymbolic.mapper.evaluator import \ + EvaluationMapper as EvaluationMapperBase from pymbolic.mapper.substitutor import \ SubstitutionMapper as SubstitutionMapperBase from pymbolic.mapper.stringifier import \ @@ -170,6 +172,11 @@ class IdentityMapper(IdentityMapperBase, IdentityMapperMixin): pass +class PartialEvaluationMapper(EvaluationMapperBase, IdentityMapperMixin): + def map_variable(self, expr): + return expr + + class WalkMapper(WalkMapperBase): def map_reduction(self, expr, *args): if not self.visit(expr):