From d3b3d04e76b60531f2c28398e47fabd5487c3530 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Tue, 26 May 2020 11:10:35 -0500
Subject: [PATCH] distribute FunctionSymbol over object arrays

---
 grudge/symbolic/primitives.py |  5 +++++
 test/test_grudge.py           | 32 +++++++++++++++++++++++++++++++-
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py
index 2318b70f..d2f1a01f 100644
--- a/grudge/symbolic/primitives.py
+++ b/grudge/symbolic/primitives.py
@@ -362,6 +362,11 @@ class FunctionSymbol(ExpressionBase, VariableBase):
     :class:`~pymbolic.primitives.Call`.
     """
 
+    def __call__(self, *exprs):
+        from pytools.obj_array import with_object_array_or_scalar_n_args
+        return with_object_array_or_scalar_n_args(
+                super(FunctionSymbol, self).__call__, *exprs)
+
     mapper_method = "map_function_symbol"
 
 
diff --git a/test/test_grudge.py b/test/test_grudge.py
index e63cea6f..fc20811f 100644
--- a/test/test_grudge.py
+++ b/test/test_grudge.py
@@ -30,7 +30,7 @@ import pyopencl as cl
 import pyopencl.array
 import pyopencl.clmath
 
-from pytools.obj_array import join_fields
+from pytools.obj_array import join_fields, make_obj_array
 
 import pytest  # noqa
 
@@ -618,6 +618,36 @@ def test_external_call(ctx_factory):
     assert (result == 5).get().all()
 
 
+@pytest.mark.parametrize("array_type", ["scalar", "vector"])
+def test_function_symbol_array(ctx_factory, array_type):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    from meshmode.mesh.generation import generate_regular_rect_mesh
+    dim = 2
+    mesh = generate_regular_rect_mesh(
+            a=(-0.5,)*dim, b=(0.5,)*dim,
+            n=(8,)*dim, order=4)
+    discr = DGDiscretizationWithBoundaries(ctx, mesh, order=4)
+    nnodes = discr.discr_from_dd(sym.DD_VOLUME).nnodes
+
+    import pyopencl.clrandom        # noqa: F401
+    if array_type == "scalar":
+        sym_x = sym.var("x")
+        x = cl.clrandom.rand(queue, nnodes, dtype=np.float)
+    elif array_type == "vector":
+        sym_x = sym.make_sym_array("x", dim)
+        x = make_obj_array([
+            cl.clrandom.rand(queue, nnodes, dtype=np.float)
+            for _ in range(dim)
+            ])
+    else:
+        raise ValueError("unknown array type")
+
+    norm = bind(discr, sym.norm(2, sym_x))(queue, x=x)
+    assert isinstance(norm, float)
+
+
 # You can test individual routines by typing
 # $ python test_grudge.py 'test_routine()'
 
-- 
GitLab