diff --git a/grudge/function_registry.py b/grudge/function_registry.py index 2f56eba4ad26d066039f5ac373619ac4ac10a656..4d521560d8e6c83a8e9fa238b3c46955dc62e7a7 100644 --- a/grudge/function_registry.py +++ b/grudge/function_registry.py @@ -136,6 +136,7 @@ class CElementwiseBinaryFunction(Function): return func(arg0, arg1) from pymbolic.primitives import Variable + @memoize_in(self, "map_call_knl_%s" % func_name) def knl(): i = Variable("i") diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py index c1360b1627a79ccafe3e52405a099f550012a410..9b6fd75806d6ed69c6188ca8bd13f184f218e955 100644 --- a/grudge/symbolic/operators.py +++ b/grudge/symbolic/operators.py @@ -146,15 +146,23 @@ class ElementwiseLinearOperator(Operator): class InterpolationOperator(Operator): def __init__(self, dd_in, dd_out): - import grudge.symbolic.primitives as prim - official_dd_in = prim.as_dofdesc(dd_in) - official_dd_out = prim.as_dofdesc(dd_out) + super(InterpolationOperator, self).__init__(dd_in, dd_out) - if official_dd_in == official_dd_out: - raise ValueError("Interpolating from {} to {}" - " does not do anything.".format(official_dd_in, official_dd_out)) + def __call__(self, expr): + from pytools.obj_array import with_object_array_or_scalar - super(InterpolationOperator, self).__init__(dd_in, dd_out) + def interp_one(subexpr): + from pymbolic.primitives import is_constant + if self.dd_in == self.dd_out: + # no-op interpolation, go away + return subexpr + elif is_constant(subexpr): + return subexpr + else: + from grudge.symbolic.primitives import OperatorBinding + return OperatorBinding(self, subexpr) + + return with_object_array_or_scalar(interp_one, expr) mapper_method = intern("map_interpolation")