From 0e42886c32db681dddc1787f669db3e5bb27baa5 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Sun, 12 May 2019 20:01:05 -0500
Subject: [PATCH] Introduce a FunctionSymbol class

---
 examples/gas_dynamics/lbm-simple.py    |  6 +++---
 examples/wave/wiggly.py                |  4 ++--
 grudge/models/em.py                    |  2 +-
 grudge/models/gas_dynamics/__init__.py |  8 ++++----
 grudge/models/wave.py                  |  2 +-
 grudge/symbolic/mappers/__init__.py    |  9 ++++++---
 grudge/symbolic/operators.py           |  6 +++---
 grudge/symbolic/primitives.py          | 23 +++++++++++++++--------
 test/test_grudge.py                    |  3 +--
 9 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/examples/gas_dynamics/lbm-simple.py b/examples/gas_dynamics/lbm-simple.py
index f99d5b76..2d3496da 100644
--- a/examples/gas_dynamics/lbm-simple.py
+++ b/examples/gas_dynamics/lbm-simple.py
@@ -63,12 +63,12 @@ def main(write_output=True, dtype=np.float32):
 
     from grudge.data import CompiledExpressionData
     def ic_expr(t, x, fields):
-        from grudge.symbolic import CFunction
+        from grudge.symbolic import FunctionSymbol
         from pymbolic.primitives import IfPositive
         from pytools.obj_array import make_obj_array
 
-        tanh = CFunction("tanh")
-        sin = CFunction("sin")
+        tanh = FunctionSymbol("tanh")
+        sin = FunctionSymbol("sin")
 
         rho = 1
         u0 = 0.05
diff --git a/examples/wave/wiggly.py b/examples/wave/wiggly.py
index f7860744..11c01212 100644
--- a/examples/wave/wiggly.py
+++ b/examples/wave/wiggly.py
@@ -70,8 +70,8 @@ def main(write_output=True,
     from grudge.models.wave import StrongWaveOperator
     op = StrongWaveOperator(-1, discr.dimensions,
             source_f=
-            sym.CFunction("sin")(source_omega*sym.ScalarParameter("t"))
-            * sym.CFunction("exp")(
+            sym.FunctionSymbol("sin")(source_omega*sym.ScalarParameter("t"))
+            * sym.FunctionSymbol("exp")(
                 -np.dot(sym_source_center_dist, sym_source_center_dist)
                 / source_width**2),
             dirichlet_tag="boundary",
diff --git a/grudge/models/em.py b/grudge/models/em.py
index f0e44f90..bf7495e2 100644
--- a/grudge/models/em.py
+++ b/grudge/models/em.py
@@ -355,7 +355,7 @@ class MaxwellOperator(HyperbolicOperator):
             return 1/sqrt(self.epsilon*self.mu)  # a number
         else:
             import grudge.symbolic as sym
-            return sym.NodalMax()(1/sym.CFunction("sqrt")(self.epsilon*self.mu))
+            return sym.NodalMax()(1/sym.FunctionSymbol("sqrt")(self.epsilon*self.mu))
 
     def max_eigenvalue(self, t, fields=None, discr=None, context={}):
         if self.fixed_material:
diff --git a/grudge/models/gas_dynamics/__init__.py b/grudge/models/gas_dynamics/__init__.py
index e5a8ddc9..4df408d8 100644
--- a/grudge/models/gas_dynamics/__init__.py
+++ b/grudge/models/gas_dynamics/__init__.py
@@ -326,8 +326,8 @@ class GasDynamicsOperator(TimeDependentOperator):
     def characteristic_velocity_optemplate(self, state):
         from grudge.symbolic.operators import ElementwiseMaxOperator
 
-        from grudge.symbolic.primitives import CFunction
-        sqrt = CFunction("sqrt")
+        from grudge.symbolic.primitives import FunctionSymbol
+        sqrt = FunctionSymbol("sqrt")
 
         sound_speed = cse(sqrt(
             self.equation_of_state.gamma*self.cse_p(state)/self.cse_rho(state)),
@@ -743,8 +743,8 @@ class GasDynamicsOperator(TimeDependentOperator):
         volq_flux = self.flux(self.volq_state())
         faceq_flux = self.flux(self.faceq_state())
 
-        from grudge.symbolic.primitives import CFunction
-        sqrt = CFunction("sqrt")
+        from grudge.symbolic.primitives import FunctionSymbol
+        sqrt = FunctionSymbol("sqrt")
 
         speed = self.characteristic_velocity_optemplate(self.state())
 
diff --git a/grudge/models/wave.py b/grudge/models/wave.py
index 40710eb2..0ac60a33 100644
--- a/grudge/models/wave.py
+++ b/grudge/models/wave.py
@@ -459,7 +459,7 @@ class VariableCoefficientWeakWaveOperator(HyperbolicOperator):
             self.radiation_tag])
 
     def max_eigenvalue(self, t, fields=None, discr=None):
-        return sym.NodalMax()(sym.CFunction("fabs")(self.c))
+        return sym.NodalMax()(sym.FunctionSymbol("fabs")(self.c))
 
 # }}}
 
diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py
index bfb60879..304bfb4c 100644
--- a/grudge/symbolic/mappers/__init__.py
+++ b/grudge/symbolic/mappers/__init__.py
@@ -211,6 +211,7 @@ class IdentityMapperMixin(LocalOpReducerMixin, FluxOpReducerMixin):
         # it's a leaf--no changing children
         return expr
 
+    map_function_symbol = map_grudge_variable
     map_ones = map_grudge_variable
     map_node_coordinate_component = map_grudge_variable
 
@@ -277,7 +278,7 @@ class FlopCounter(
     def map_grudge_variable(self, expr):
         return 0
 
-    def map_call(self, expr):
+    def map_function_symbol(self, expr):
         return 1
 
     def map_ones(self, expr):
@@ -842,6 +843,9 @@ class StringifyMapper(pymbolic.mapper.stringifier.StringifyMapper):
     def map_grudge_variable(self, expr, enclosing_prec):
         return "%s:%s" % (expr.name, self._format_dd(expr.dd))
 
+    def map_function_symbol(self, expr, enclosing_prec):
+        return expr
+
     def map_interpolation(self, expr, enclosing_prec):
         return "Interp" + self._format_op_dd(expr)
 
@@ -1227,8 +1231,7 @@ class CollectorMixin(OperatorReducerMixin, LocalOpReducerMixin, FluxOpReducerMix
         return OrderedSet()
 
     map_grudge_variable = map_constant
-    # Found in function call nodes
-    map_variable = map_grudge_variable
+    map_function_symbol = map_constant
 
     map_ones = map_grudge_variable
     map_node_coordinate_component = map_grudge_variable
diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py
index 53fb1422..a1899ed2 100644
--- a/grudge/symbolic/operators.py
+++ b/grudge/symbolic/operators.py
@@ -602,16 +602,16 @@ def norm(p, arg, dd=None):
 
     if p == 2:
         norm_squared = sym.NodalSum(dd_in=dd)(
-                sym.CFunction("fabs")(
+                sym.FunctionSymbol("fabs")(
                     arg * sym.MassOperator()(arg)))
 
         if isinstance(norm_squared, np.ndarray):
             norm_squared = norm_squared.sum()
 
-        return sym.CFunction("sqrt")(norm_squared)
+        return sym.FunctionSymbol("sqrt")(norm_squared)
 
     elif p == np.Inf:
-        result = sym.NodalMax(dd_in=dd)(sym.CFunction("fabs")(arg))
+        result = sym.NodalMax(dd_in=dd)(sym.FunctionSymbol("fabs")(arg))
         from pymbolic.primitives import Max
 
         if isinstance(result, np.ndarray):
diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py
index 3c2dc96e..c59f1518 100644
--- a/grudge/symbolic/primitives.py
+++ b/grudge/symbolic/primitives.py
@@ -354,14 +354,21 @@ def make_sym_mv(name, dim, var_factory=None):
             make_sym_array(name, dim, var_factory))
 
 
-# function symbols
-CFunction = Variable
-sqrt = Variable("sqrt")
-exp = Variable("exp")
-sin = Variable("sin")
-cos = Variable("cos")
-bessel_j = Variable("bessel_j")
-bessel_y = Variable("bessel_y")
+class FunctionSymbol(ExpressionBase, pymbolic.primitives.Variable):
+    """A symbol to be used as the function argument of
+    :class:`pymbolic.primitives.Call`.
+
+    """
+
+    mapper_method = "map_function_symbol"
+
+
+sqrt = FunctionSymbol("sqrt")
+exp = FunctionSymbol("exp")
+sin = FunctionSymbol("sin")
+cos = FunctionSymbol("cos")
+bessel_j = FunctionSymbol("bessel_j")
+bessel_y = FunctionSymbol("bessel_y")
 
 # }}}
 
diff --git a/test/test_grudge.py b/test/test_grudge.py
index f646dff4..288922e0 100644
--- a/test/test_grudge.py
+++ b/test/test_grudge.py
@@ -554,10 +554,9 @@ def test_external_call(ctx_factory):
     discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=1)
 
     ones = sym.Ones(sym.DD_VOLUME)
-    from pymbolic.primitives import Variable
     op = (
             ones * 3
-            + Variable("double")(ones))
+            + sym.FunctionSymbol("double")(ones))
 
     from grudge.function_registry import (
             base_function_registry, register_external_function)
-- 
GitLab