From a43b2e2bb3ca4fc7e7204c1a7402e8bdd499540f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 20 Oct 2015 11:13:09 -0500 Subject: [PATCH] Fix global -> ref rewriting for derivatives --- grudge/symbolic/mappers/__init__.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py index 32ad7b8c..9c08abb0 100644 --- a/grudge/symbolic/mappers/__init__.py +++ b/grudge/symbolic/mappers/__init__.py @@ -492,16 +492,18 @@ class GlobalToReferenceMapper(CSECachingMapperMixin, IdentityMapper): jac_notag = sym.area_element(self.ambient_dim, self.dim, where=expr.op.where, quadrature_tag=None) - def rewrite_derivative(ref_class, field, quadrature_tag, where, - with_jacobian=True): + def rewrite_derivative(ref_class, field, where, + input_quadrature_tag=None, with_jacobian=True): jac_tag = sym.area_element(self.ambient_dim, self.dim, - where=expr.op.where, quadrature_tag=quadrature_tag) + where=expr.op.where, quadrature_tag=input_quadrature_tag) - if quadrature_tag is not None: - diff_kwargs = dict(quadrature_tag=quadrature_tag) + if input_quadrature_tag is not None: + diff_kwargs = dict(input_quadrature_tag=input_quadrature_tag) else: diff_kwargs = {} + diff_kwargs["where"] = where + rec_field = self.rec(field) if with_jacobian: rec_field = jac_tag * rec_field @@ -510,7 +512,7 @@ class GlobalToReferenceMapper(CSECachingMapperMixin, IdentityMapper): rst_axis, expr.op.xyz_axis, ambient_dim=self.ambient_dim, dim=self.dim) * ref_class(rst_axis, **diff_kwargs)(rec_field) - for rst_axis in range(self.dimensions)) + for rst_axis in range(self.dim)) if isinstance(expr.op, op.MassOperator): return op.ReferenceMassOperator()( @@ -546,7 +548,7 @@ class GlobalToReferenceMapper(CSECachingMapperMixin, IdentityMapper): elif isinstance(expr.op, op.QuadratureStiffnessTOperator): return rewrite_derivative( op.ReferenceQuadratureStiffnessTOperator, - expr.field, quadrature_tag=expr.op.quadrature_tag, + expr.field, input_quadrature_tag=expr.op.input_quadrature_tag, where=expr.op.where) elif isinstance(expr.op, op.MInvSTOperator): -- GitLab