From d3b3d04e76b60531f2c28398e47fabd5487c3530 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Tue, 26 May 2020 11:10:35 -0500 Subject: [PATCH] distribute FunctionSymbol over object arrays --- grudge/symbolic/primitives.py | 5 +++++ test/test_grudge.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py index 2318b70f..d2f1a01f 100644 --- a/grudge/symbolic/primitives.py +++ b/grudge/symbolic/primitives.py @@ -362,6 +362,11 @@ class FunctionSymbol(ExpressionBase, VariableBase): :class:`~pymbolic.primitives.Call`. """ + def __call__(self, *exprs): + from pytools.obj_array import with_object_array_or_scalar_n_args + return with_object_array_or_scalar_n_args( + super(FunctionSymbol, self).__call__, *exprs) + mapper_method = "map_function_symbol" diff --git a/test/test_grudge.py b/test/test_grudge.py index e63cea6f..fc20811f 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -30,7 +30,7 @@ import pyopencl as cl import pyopencl.array import pyopencl.clmath -from pytools.obj_array import join_fields +from pytools.obj_array import join_fields, make_obj_array import pytest # noqa @@ -618,6 +618,36 @@ def test_external_call(ctx_factory): assert (result == 5).get().all() +@pytest.mark.parametrize("array_type", ["scalar", "vector"]) +def test_function_symbol_array(ctx_factory, array_type): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + from meshmode.mesh.generation import generate_regular_rect_mesh + dim = 2 + mesh = generate_regular_rect_mesh( + a=(-0.5,)*dim, b=(0.5,)*dim, + n=(8,)*dim, order=4) + discr = DGDiscretizationWithBoundaries(ctx, mesh, order=4) + nnodes = discr.discr_from_dd(sym.DD_VOLUME).nnodes + + import pyopencl.clrandom # noqa: F401 + if array_type == "scalar": + sym_x = sym.var("x") + x = cl.clrandom.rand(queue, nnodes, dtype=np.float) + elif array_type == "vector": + sym_x = sym.make_sym_array("x", dim) + x = make_obj_array([ + cl.clrandom.rand(queue, nnodes, dtype=np.float) + for _ in range(dim) + ]) + else: + raise ValueError("unknown array type") + + norm = bind(discr, sym.norm(2, sym_x))(queue, x=x) + assert isinstance(norm, float) + + # You can test individual routines by typing # $ python test_grudge.py 'test_routine()' -- GitLab