diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py index 4a58961b66966485294092ba9baa60476b209afd..972e4056729808778e5875ba786ef25c1348f362 100644 --- a/grudge/symbolic/compiler.py +++ b/grudge/symbolic/compiler.py @@ -405,7 +405,12 @@ class Code(object): for dep in insn.get_dependencies(): try: - writer = var_to_writer[dep.name] + if isinstance(dep, Subscript): + dep_name = dep.aggregate.name + else: + dep_name = dep.name + + writer = var_to_writer[dep_name] except KeyError: # input variables won't be found pass @@ -604,6 +609,11 @@ def aggregate_assignments(inf_mapper, instructions, result, # {{{ aggregation helpers def get_complete_origins_set(insn, skip_levels=0): + try: + return insn_to_origins_cache[insn] + except KeyError: + pass + if skip_levels < 0: skip_levels = 0 @@ -617,6 +627,8 @@ def aggregate_assignments(inf_mapper, instructions, result, result |= get_complete_origins_set( dep_origin, skip_levels-1) + insn_to_origins_cache[insn] = result + return result var_assignees_cache = {} @@ -646,6 +658,8 @@ def aggregate_assignments(inf_mapper, instructions, result, # {{{ main aggregation pass + insn_to_origins_cache = {} + origins_map = dict( (assignee, insn) for insn in instructions diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py index 2b0b2242de11e6180d7889a2ec87a3b9e18a08f7..c2fba3d6a757fa9ed80babf4f7315707bf4e7858 100644 --- a/grudge/symbolic/mappers/__init__.py +++ b/grudge/symbolic/mappers/__init__.py @@ -933,9 +933,13 @@ class QuadratureCheckerAndRemover(CSECachingMapperMixin, IdentityMapper): # {{{ simplification / optimization class ConstantToNumpyConversionMapper( + CSECachingMapperMixin, pymbolic.mapper.constant_converter.ConstantToNumpyConversionMapper, IdentityMapperMixin): - pass + map_common_subexpression_uncached = ( + pymbolic.mapper.constant_converter + .ConstantToNumpyConversionMapper + .map_common_subexpression) class CommutativeConstantFoldingMapper(