From 0e42886c32db681dddc1787f669db3e5bb27baa5 Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Sun, 12 May 2019 20:01:05 -0500 Subject: [PATCH] Introduce a FunctionSymbol class --- examples/gas_dynamics/lbm-simple.py | 6 +++--- examples/wave/wiggly.py | 4 ++-- grudge/models/em.py | 2 +- grudge/models/gas_dynamics/__init__.py | 8 ++++---- grudge/models/wave.py | 2 +- grudge/symbolic/mappers/__init__.py | 9 ++++++--- grudge/symbolic/operators.py | 6 +++--- grudge/symbolic/primitives.py | 23 +++++++++++++++-------- test/test_grudge.py | 3 +-- 9 files changed, 36 insertions(+), 27 deletions(-) diff --git a/examples/gas_dynamics/lbm-simple.py b/examples/gas_dynamics/lbm-simple.py index f99d5b76..2d3496da 100644 --- a/examples/gas_dynamics/lbm-simple.py +++ b/examples/gas_dynamics/lbm-simple.py @@ -63,12 +63,12 @@ def main(write_output=True, dtype=np.float32): from grudge.data import CompiledExpressionData def ic_expr(t, x, fields): - from grudge.symbolic import CFunction + from grudge.symbolic import FunctionSymbol from pymbolic.primitives import IfPositive from pytools.obj_array import make_obj_array - tanh = CFunction("tanh") - sin = CFunction("sin") + tanh = FunctionSymbol("tanh") + sin = FunctionSymbol("sin") rho = 1 u0 = 0.05 diff --git a/examples/wave/wiggly.py b/examples/wave/wiggly.py index f7860744..11c01212 100644 --- a/examples/wave/wiggly.py +++ b/examples/wave/wiggly.py @@ -70,8 +70,8 @@ def main(write_output=True, from grudge.models.wave import StrongWaveOperator op = StrongWaveOperator(-1, discr.dimensions, source_f= - sym.CFunction("sin")(source_omega*sym.ScalarParameter("t")) - * sym.CFunction("exp")( + sym.FunctionSymbol("sin")(source_omega*sym.ScalarParameter("t")) + * sym.FunctionSymbol("exp")( -np.dot(sym_source_center_dist, sym_source_center_dist) / source_width**2), dirichlet_tag="boundary", diff --git a/grudge/models/em.py b/grudge/models/em.py index f0e44f90..bf7495e2 100644 --- a/grudge/models/em.py +++ b/grudge/models/em.py @@ -355,7 +355,7 @@ class MaxwellOperator(HyperbolicOperator): return 1/sqrt(self.epsilon*self.mu) # a number else: import grudge.symbolic as sym - return sym.NodalMax()(1/sym.CFunction("sqrt")(self.epsilon*self.mu)) + return sym.NodalMax()(1/sym.FunctionSymbol("sqrt")(self.epsilon*self.mu)) def max_eigenvalue(self, t, fields=None, discr=None, context={}): if self.fixed_material: diff --git a/grudge/models/gas_dynamics/__init__.py b/grudge/models/gas_dynamics/__init__.py index e5a8ddc9..4df408d8 100644 --- a/grudge/models/gas_dynamics/__init__.py +++ b/grudge/models/gas_dynamics/__init__.py @@ -326,8 +326,8 @@ class GasDynamicsOperator(TimeDependentOperator): def characteristic_velocity_optemplate(self, state): from grudge.symbolic.operators import ElementwiseMaxOperator - from grudge.symbolic.primitives import CFunction - sqrt = CFunction("sqrt") + from grudge.symbolic.primitives import FunctionSymbol + sqrt = FunctionSymbol("sqrt") sound_speed = cse(sqrt( self.equation_of_state.gamma*self.cse_p(state)/self.cse_rho(state)), @@ -743,8 +743,8 @@ class GasDynamicsOperator(TimeDependentOperator): volq_flux = self.flux(self.volq_state()) faceq_flux = self.flux(self.faceq_state()) - from grudge.symbolic.primitives import CFunction - sqrt = CFunction("sqrt") + from grudge.symbolic.primitives import FunctionSymbol + sqrt = FunctionSymbol("sqrt") speed = self.characteristic_velocity_optemplate(self.state()) diff --git a/grudge/models/wave.py b/grudge/models/wave.py index 40710eb2..0ac60a33 100644 --- a/grudge/models/wave.py +++ b/grudge/models/wave.py @@ -459,7 +459,7 @@ class VariableCoefficientWeakWaveOperator(HyperbolicOperator): self.radiation_tag]) def max_eigenvalue(self, t, fields=None, discr=None): - return sym.NodalMax()(sym.CFunction("fabs")(self.c)) + return sym.NodalMax()(sym.FunctionSymbol("fabs")(self.c)) # }}} diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py index bfb60879..304bfb4c 100644 --- a/grudge/symbolic/mappers/__init__.py +++ b/grudge/symbolic/mappers/__init__.py @@ -211,6 +211,7 @@ class IdentityMapperMixin(LocalOpReducerMixin, FluxOpReducerMixin): # it's a leaf--no changing children return expr + map_function_symbol = map_grudge_variable map_ones = map_grudge_variable map_node_coordinate_component = map_grudge_variable @@ -277,7 +278,7 @@ class FlopCounter( def map_grudge_variable(self, expr): return 0 - def map_call(self, expr): + def map_function_symbol(self, expr): return 1 def map_ones(self, expr): @@ -842,6 +843,9 @@ class StringifyMapper(pymbolic.mapper.stringifier.StringifyMapper): def map_grudge_variable(self, expr, enclosing_prec): return "%s:%s" % (expr.name, self._format_dd(expr.dd)) + def map_function_symbol(self, expr, enclosing_prec): + return expr + def map_interpolation(self, expr, enclosing_prec): return "Interp" + self._format_op_dd(expr) @@ -1227,8 +1231,7 @@ class CollectorMixin(OperatorReducerMixin, LocalOpReducerMixin, FluxOpReducerMix return OrderedSet() map_grudge_variable = map_constant - # Found in function call nodes - map_variable = map_grudge_variable + map_function_symbol = map_constant map_ones = map_grudge_variable map_node_coordinate_component = map_grudge_variable diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py index 53fb1422..a1899ed2 100644 --- a/grudge/symbolic/operators.py +++ b/grudge/symbolic/operators.py @@ -602,16 +602,16 @@ def norm(p, arg, dd=None): if p == 2: norm_squared = sym.NodalSum(dd_in=dd)( - sym.CFunction("fabs")( + sym.FunctionSymbol("fabs")( arg * sym.MassOperator()(arg))) if isinstance(norm_squared, np.ndarray): norm_squared = norm_squared.sum() - return sym.CFunction("sqrt")(norm_squared) + return sym.FunctionSymbol("sqrt")(norm_squared) elif p == np.Inf: - result = sym.NodalMax(dd_in=dd)(sym.CFunction("fabs")(arg)) + result = sym.NodalMax(dd_in=dd)(sym.FunctionSymbol("fabs")(arg)) from pymbolic.primitives import Max if isinstance(result, np.ndarray): diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py index 3c2dc96e..c59f1518 100644 --- a/grudge/symbolic/primitives.py +++ b/grudge/symbolic/primitives.py @@ -354,14 +354,21 @@ def make_sym_mv(name, dim, var_factory=None): make_sym_array(name, dim, var_factory)) -# function symbols -CFunction = Variable -sqrt = Variable("sqrt") -exp = Variable("exp") -sin = Variable("sin") -cos = Variable("cos") -bessel_j = Variable("bessel_j") -bessel_y = Variable("bessel_y") +class FunctionSymbol(ExpressionBase, pymbolic.primitives.Variable): + """A symbol to be used as the function argument of + :class:`pymbolic.primitives.Call`. + + """ + + mapper_method = "map_function_symbol" + + +sqrt = FunctionSymbol("sqrt") +exp = FunctionSymbol("exp") +sin = FunctionSymbol("sin") +cos = FunctionSymbol("cos") +bessel_j = FunctionSymbol("bessel_j") +bessel_y = FunctionSymbol("bessel_y") # }}} diff --git a/test/test_grudge.py b/test/test_grudge.py index f646dff4..288922e0 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -554,10 +554,9 @@ def test_external_call(ctx_factory): discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=1) ones = sym.Ones(sym.DD_VOLUME) - from pymbolic.primitives import Variable op = ( ones * 3 - + Variable("double")(ones)) + + sym.FunctionSymbol("double")(ones)) from grudge.function_registry import ( base_function_registry, register_external_function) -- GitLab