From 0995cf1dfa0670b2d6b2645ab5bce2fce703069b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 26 Apr 2021 18:23:41 -0500 Subject: [PATCH] Fix 'compiler' for the case that a subexpression occurs both in discr-cached and non-discr-cached code --- grudge/symbolic/compiler.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py index 890b99cf..742483ae 100644 --- a/grudge/symbolic/compiler.py +++ b/grudge/symbolic/compiler.py @@ -1112,6 +1112,12 @@ class CodeGenerationState(Record): .. attribute:: generating_discr_code """ + def get_expr_to_var(self, compiler): + if self.generating_discr_code: + return compiler.discr_expr_to_var + else: + return compiler.eval_expr_to_var + def get_code_list(self, compiler): if self.generating_discr_code: return compiler.discr_code @@ -1130,9 +1136,10 @@ class OperatorCompiler(mappers.IdentityMapper): self.discr_code = [] self.discr_scope_names_created = set() self.discr_scope_names_copied_to_eval = set() + self.discr_expr_to_var = {} self.eval_code = [] - self.expr_to_var = {} + self.eval_expr_to_var = {} self.assigned_names = set() @@ -1260,8 +1267,10 @@ class OperatorCompiler(mappers.IdentityMapper): return var(expr_name) else: + expr_to_var = codegen_state.get_expr_to_var(self) + try: - return self.expr_to_var[expr.child] + return expr_to_var[expr.child] except KeyError: priority = getattr(expr, "priority", 0) @@ -1271,7 +1280,7 @@ class OperatorCompiler(mappers.IdentityMapper): codegen_state, rec_child, priority=priority, prefix=expr.prefix) - self.expr_to_var[expr.child] = cse_var + expr_to_var[expr.child] = cse_var return cse_var def map_operator_binding(self, expr, codegen_state, name_hint=None): @@ -1308,8 +1317,9 @@ class OperatorCompiler(mappers.IdentityMapper): for par in expr.parameters])) def map_ref_diff_op_binding(self, expr, codegen_state): + expr_to_var = codegen_state.get_expr_to_var(self) try: - return self.expr_to_var[expr] + return expr_to_var[expr] except KeyError: all_diffs = [diff for diff in self.diff_ops @@ -1332,23 +1342,25 @@ class OperatorCompiler(mappers.IdentityMapper): from pymbolic import var for n, d in zip(names, all_diffs): - self.expr_to_var[d] = var(n) + expr_to_var[d] = var(n) - return self.expr_to_var[expr] + return expr_to_var[expr] def map_rank_data_swap_binding(self, expr, codegen_state, name_hint): + expr_to_var = codegen_state.get_expr_to_var(self) + try: - return self.expr_to_var[expr] + return expr_to_var[expr] except KeyError: field = self.rec(expr.field, codegen_state) name = self.name_gen("raw_rank%02d_bdry_data" % expr.op.i_remote_part) field_insn = RankDataSwapAssign(name=name, field=field, op=expr.op) codegen_state.get_code_list(self).append(field_insn) field_var = Variable(field_insn.name) - self.expr_to_var[expr] = self.assign_to_new_var(codegen_state, + expr_to_var[expr] = self.assign_to_new_var(codegen_state, expr.op(field_var), prefix=name_hint) - return self.expr_to_var[expr] + return expr_to_var[expr] # }}} -- GitLab