diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py index 890b99cff872efc3db59e7db70b936780b8f16c2..742483aed9e02264c5ac65cfb40ee8ef0dbb8ee5 100644 --- a/grudge/symbolic/compiler.py +++ b/grudge/symbolic/compiler.py @@ -1112,6 +1112,12 @@ class CodeGenerationState(Record): .. attribute:: generating_discr_code """ + def get_expr_to_var(self, compiler): + if self.generating_discr_code: + return compiler.discr_expr_to_var + else: + return compiler.eval_expr_to_var + def get_code_list(self, compiler): if self.generating_discr_code: return compiler.discr_code @@ -1130,9 +1136,10 @@ class OperatorCompiler(mappers.IdentityMapper): self.discr_code = [] self.discr_scope_names_created = set() self.discr_scope_names_copied_to_eval = set() + self.discr_expr_to_var = {} self.eval_code = [] - self.expr_to_var = {} + self.eval_expr_to_var = {} self.assigned_names = set() @@ -1260,8 +1267,10 @@ class OperatorCompiler(mappers.IdentityMapper): return var(expr_name) else: + expr_to_var = codegen_state.get_expr_to_var(self) + try: - return self.expr_to_var[expr.child] + return expr_to_var[expr.child] except KeyError: priority = getattr(expr, "priority", 0) @@ -1271,7 +1280,7 @@ class OperatorCompiler(mappers.IdentityMapper): codegen_state, rec_child, priority=priority, prefix=expr.prefix) - self.expr_to_var[expr.child] = cse_var + expr_to_var[expr.child] = cse_var return cse_var def map_operator_binding(self, expr, codegen_state, name_hint=None): @@ -1308,8 +1317,9 @@ class OperatorCompiler(mappers.IdentityMapper): for par in expr.parameters])) def map_ref_diff_op_binding(self, expr, codegen_state): + expr_to_var = codegen_state.get_expr_to_var(self) try: - return self.expr_to_var[expr] + return expr_to_var[expr] except KeyError: all_diffs = [diff for diff in self.diff_ops @@ -1332,23 +1342,25 @@ class OperatorCompiler(mappers.IdentityMapper): from pymbolic import var for n, d in zip(names, all_diffs): - self.expr_to_var[d] = var(n) + expr_to_var[d] = var(n) - return self.expr_to_var[expr] + return expr_to_var[expr] def map_rank_data_swap_binding(self, expr, codegen_state, name_hint): + expr_to_var = codegen_state.get_expr_to_var(self) + try: - return self.expr_to_var[expr] + return expr_to_var[expr] except KeyError: field = self.rec(expr.field, codegen_state) name = self.name_gen("raw_rank%02d_bdry_data" % expr.op.i_remote_part) field_insn = RankDataSwapAssign(name=name, field=field, op=expr.op) codegen_state.get_code_list(self).append(field_insn) field_var = Variable(field_insn.name) - self.expr_to_var[expr] = self.assign_to_new_var(codegen_state, + expr_to_var[expr] = self.assign_to_new_var(codegen_state, expr.op(field_var), prefix=name_hint) - return self.expr_to_var[expr] + return expr_to_var[expr] # }}}