diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py index 32ad7b8c87419ef21c686247131017ee984f6da7..9c08abb05cf113371ec0fe6ca071a39b74818e37 100644 --- a/grudge/symbolic/mappers/__init__.py +++ b/grudge/symbolic/mappers/__init__.py @@ -492,16 +492,18 @@ class GlobalToReferenceMapper(CSECachingMapperMixin, IdentityMapper): jac_notag = sym.area_element(self.ambient_dim, self.dim, where=expr.op.where, quadrature_tag=None) - def rewrite_derivative(ref_class, field, quadrature_tag, where, - with_jacobian=True): + def rewrite_derivative(ref_class, field, where, + input_quadrature_tag=None, with_jacobian=True): jac_tag = sym.area_element(self.ambient_dim, self.dim, - where=expr.op.where, quadrature_tag=quadrature_tag) + where=expr.op.where, quadrature_tag=input_quadrature_tag) - if quadrature_tag is not None: - diff_kwargs = dict(quadrature_tag=quadrature_tag) + if input_quadrature_tag is not None: + diff_kwargs = dict(input_quadrature_tag=input_quadrature_tag) else: diff_kwargs = {} + diff_kwargs["where"] = where + rec_field = self.rec(field) if with_jacobian: rec_field = jac_tag * rec_field @@ -510,7 +512,7 @@ class GlobalToReferenceMapper(CSECachingMapperMixin, IdentityMapper): rst_axis, expr.op.xyz_axis, ambient_dim=self.ambient_dim, dim=self.dim) * ref_class(rst_axis, **diff_kwargs)(rec_field) - for rst_axis in range(self.dimensions)) + for rst_axis in range(self.dim)) if isinstance(expr.op, op.MassOperator): return op.ReferenceMassOperator()( @@ -546,7 +548,7 @@ class GlobalToReferenceMapper(CSECachingMapperMixin, IdentityMapper): elif isinstance(expr.op, op.QuadratureStiffnessTOperator): return rewrite_derivative( op.ReferenceQuadratureStiffnessTOperator, - expr.field, quadrature_tag=expr.op.quadrature_tag, + expr.field, input_quadrature_tag=expr.op.input_quadrature_tag, where=expr.op.where) elif isinstance(expr.op, op.MInvSTOperator):