diff --git a/examples/gas_dynamics/lbm-simple.py b/examples/gas_dynamics/lbm-simple.py index f99d5b76994b46f079b41af9f74bf69789ae18f5..2d3496da340fbda01adeb33b6a59290b76b1c698 100644 --- a/examples/gas_dynamics/lbm-simple.py +++ b/examples/gas_dynamics/lbm-simple.py @@ -63,12 +63,12 @@ def main(write_output=True, dtype=np.float32): from grudge.data import CompiledExpressionData def ic_expr(t, x, fields): - from grudge.symbolic import CFunction + from grudge.symbolic import FunctionSymbol from pymbolic.primitives import IfPositive from pytools.obj_array import make_obj_array - tanh = CFunction("tanh") - sin = CFunction("sin") + tanh = FunctionSymbol("tanh") + sin = FunctionSymbol("sin") rho = 1 u0 = 0.05 diff --git a/examples/wave/wiggly.py b/examples/wave/wiggly.py index f786074443d2123882bcaecf85972777b6694622..11c0121214982b70fb891b443dfec267ce5defb6 100644 --- a/examples/wave/wiggly.py +++ b/examples/wave/wiggly.py @@ -70,8 +70,8 @@ def main(write_output=True, from grudge.models.wave import StrongWaveOperator op = StrongWaveOperator(-1, discr.dimensions, source_f= - sym.CFunction("sin")(source_omega*sym.ScalarParameter("t")) - * sym.CFunction("exp")( + sym.FunctionSymbol("sin")(source_omega*sym.ScalarParameter("t")) + * sym.FunctionSymbol("exp")( -np.dot(sym_source_center_dist, sym_source_center_dist) / source_width**2), dirichlet_tag="boundary", diff --git a/grudge/models/em.py b/grudge/models/em.py index f0e44f903d226e28710709651e948293a534d1b8..bf7495e2c9447b54d37d49a78905e549d9df54d8 100644 --- a/grudge/models/em.py +++ b/grudge/models/em.py @@ -355,7 +355,7 @@ class MaxwellOperator(HyperbolicOperator): return 1/sqrt(self.epsilon*self.mu) # a number else: import grudge.symbolic as sym - return sym.NodalMax()(1/sym.CFunction("sqrt")(self.epsilon*self.mu)) + return sym.NodalMax()(1/sym.FunctionSymbol("sqrt")(self.epsilon*self.mu)) def max_eigenvalue(self, t, fields=None, discr=None, context={}): if self.fixed_material: diff --git a/grudge/models/gas_dynamics/__init__.py b/grudge/models/gas_dynamics/__init__.py index e5a8ddc9418f1bac4776f252d2951fab3d3c1a2b..4df408d811a826a1f924b855f0ff8a9fce8a29f2 100644 --- a/grudge/models/gas_dynamics/__init__.py +++ b/grudge/models/gas_dynamics/__init__.py @@ -326,8 +326,8 @@ class GasDynamicsOperator(TimeDependentOperator): def characteristic_velocity_optemplate(self, state): from grudge.symbolic.operators import ElementwiseMaxOperator - from grudge.symbolic.primitives import CFunction - sqrt = CFunction("sqrt") + from grudge.symbolic.primitives import FunctionSymbol + sqrt = FunctionSymbol("sqrt") sound_speed = cse(sqrt( self.equation_of_state.gamma*self.cse_p(state)/self.cse_rho(state)), @@ -743,8 +743,8 @@ class GasDynamicsOperator(TimeDependentOperator): volq_flux = self.flux(self.volq_state()) faceq_flux = self.flux(self.faceq_state()) - from grudge.symbolic.primitives import CFunction - sqrt = CFunction("sqrt") + from grudge.symbolic.primitives import FunctionSymbol + sqrt = FunctionSymbol("sqrt") speed = self.characteristic_velocity_optemplate(self.state()) diff --git a/grudge/models/wave.py b/grudge/models/wave.py index 40710eb2363b7e475f1be299f1b69540343309be..0ac60a336a4f72f12ba4f3eeb0bc2d79c9f07509 100644 --- a/grudge/models/wave.py +++ b/grudge/models/wave.py @@ -459,7 +459,7 @@ class VariableCoefficientWeakWaveOperator(HyperbolicOperator): self.radiation_tag]) def max_eigenvalue(self, t, fields=None, discr=None): - return sym.NodalMax()(sym.CFunction("fabs")(self.c)) + return sym.NodalMax()(sym.FunctionSymbol("fabs")(self.c)) # }}} diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py index bfb60879c911918f901f460f197d4004a38068ae..304bfb4c66db61a5a1f539137179962a9717251d 100644 --- a/grudge/symbolic/mappers/__init__.py +++ b/grudge/symbolic/mappers/__init__.py @@ -211,6 +211,7 @@ class IdentityMapperMixin(LocalOpReducerMixin, FluxOpReducerMixin): # it's a leaf--no changing children return expr + map_function_symbol = map_grudge_variable map_ones = map_grudge_variable map_node_coordinate_component = map_grudge_variable @@ -277,7 +278,7 @@ class FlopCounter( def map_grudge_variable(self, expr): return 0 - def map_call(self, expr): + def map_function_symbol(self, expr): return 1 def map_ones(self, expr): @@ -842,6 +843,9 @@ class StringifyMapper(pymbolic.mapper.stringifier.StringifyMapper): def map_grudge_variable(self, expr, enclosing_prec): return "%s:%s" % (expr.name, self._format_dd(expr.dd)) + def map_function_symbol(self, expr, enclosing_prec): + return expr + def map_interpolation(self, expr, enclosing_prec): return "Interp" + self._format_op_dd(expr) @@ -1227,8 +1231,7 @@ class CollectorMixin(OperatorReducerMixin, LocalOpReducerMixin, FluxOpReducerMix return OrderedSet() map_grudge_variable = map_constant - # Found in function call nodes - map_variable = map_grudge_variable + map_function_symbol = map_constant map_ones = map_grudge_variable map_node_coordinate_component = map_grudge_variable diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py index 53fb1422619db271786abb9c2c17df5b3f9a97c0..a1899ed21ef6431f7b57a81fcf143680433b7b53 100644 --- a/grudge/symbolic/operators.py +++ b/grudge/symbolic/operators.py @@ -602,16 +602,16 @@ def norm(p, arg, dd=None): if p == 2: norm_squared = sym.NodalSum(dd_in=dd)( - sym.CFunction("fabs")( + sym.FunctionSymbol("fabs")( arg * sym.MassOperator()(arg))) if isinstance(norm_squared, np.ndarray): norm_squared = norm_squared.sum() - return sym.CFunction("sqrt")(norm_squared) + return sym.FunctionSymbol("sqrt")(norm_squared) elif p == np.Inf: - result = sym.NodalMax(dd_in=dd)(sym.CFunction("fabs")(arg)) + result = sym.NodalMax(dd_in=dd)(sym.FunctionSymbol("fabs")(arg)) from pymbolic.primitives import Max if isinstance(result, np.ndarray): diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py index 3c2dc96ef25af3254b9bac3e7399d9567223c97d..c59f1518b5034a9b69c3aca4a22c0c4bab92d7a6 100644 --- a/grudge/symbolic/primitives.py +++ b/grudge/symbolic/primitives.py @@ -354,14 +354,21 @@ def make_sym_mv(name, dim, var_factory=None): make_sym_array(name, dim, var_factory)) -# function symbols -CFunction = Variable -sqrt = Variable("sqrt") -exp = Variable("exp") -sin = Variable("sin") -cos = Variable("cos") -bessel_j = Variable("bessel_j") -bessel_y = Variable("bessel_y") +class FunctionSymbol(ExpressionBase, pymbolic.primitives.Variable): + """A symbol to be used as the function argument of + :class:`pymbolic.primitives.Call`. + + """ + + mapper_method = "map_function_symbol" + + +sqrt = FunctionSymbol("sqrt") +exp = FunctionSymbol("exp") +sin = FunctionSymbol("sin") +cos = FunctionSymbol("cos") +bessel_j = FunctionSymbol("bessel_j") +bessel_y = FunctionSymbol("bessel_y") # }}} diff --git a/test/test_grudge.py b/test/test_grudge.py index f646dff4cb8b1c54377b9580cc1baf6001e4c852..288922e07529aa70bffbc1f812efe582e146ff60 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -554,10 +554,9 @@ def test_external_call(ctx_factory): discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=1) ones = sym.Ones(sym.DD_VOLUME) - from pymbolic.primitives import Variable op = ( ones * 3 - + Variable("double")(ones)) + + sym.FunctionSymbol("double")(ones)) from grudge.function_registry import ( base_function_registry, register_external_function)