diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py index 2318b70fbca1e4d4d0fa3428d491d53e10baae7e..d2f1a01f1c5415e7c2877625156bc0087ab104a6 100644 --- a/grudge/symbolic/primitives.py +++ b/grudge/symbolic/primitives.py @@ -362,6 +362,11 @@ class FunctionSymbol(ExpressionBase, VariableBase): :class:`~pymbolic.primitives.Call`. """ + def __call__(self, *exprs): + from pytools.obj_array import with_object_array_or_scalar_n_args + return with_object_array_or_scalar_n_args( + super(FunctionSymbol, self).__call__, *exprs) + mapper_method = "map_function_symbol" diff --git a/test/test_grudge.py b/test/test_grudge.py index e63cea6f80729a2a5796dfcc4276e94ac3cc8f91..fc20811f9e126dea877fa61d4111b3b336db606f 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -30,7 +30,7 @@ import pyopencl as cl import pyopencl.array import pyopencl.clmath -from pytools.obj_array import join_fields +from pytools.obj_array import join_fields, make_obj_array import pytest # noqa @@ -618,6 +618,36 @@ def test_external_call(ctx_factory): assert (result == 5).get().all() +@pytest.mark.parametrize("array_type", ["scalar", "vector"]) +def test_function_symbol_array(ctx_factory, array_type): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + from meshmode.mesh.generation import generate_regular_rect_mesh + dim = 2 + mesh = generate_regular_rect_mesh( + a=(-0.5,)*dim, b=(0.5,)*dim, + n=(8,)*dim, order=4) + discr = DGDiscretizationWithBoundaries(ctx, mesh, order=4) + nnodes = discr.discr_from_dd(sym.DD_VOLUME).nnodes + + import pyopencl.clrandom # noqa: F401 + if array_type == "scalar": + sym_x = sym.var("x") + x = cl.clrandom.rand(queue, nnodes, dtype=np.float) + elif array_type == "vector": + sym_x = sym.make_sym_array("x", dim) + x = make_obj_array([ + cl.clrandom.rand(queue, nnodes, dtype=np.float) + for _ in range(dim) + ]) + else: + raise ValueError("unknown array type") + + norm = bind(discr, sym.norm(2, sym_x))(queue, x=x) + assert isinstance(norm, float) + + # You can test individual routines by typing # $ python test_grudge.py 'test_routine()'