From f75c36af4e72035fe01b652e1453413d6f3a787d Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Sun, 16 Aug 2020 20:17:02 -0500 Subject: [PATCH] fix map_if handling of scalar conditions --- grudge/execution.py | 12 ++++++------ test/test_grudge.py | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/grudge/execution.py b/grudge/execution.py index 88c0d318..0f6711e8 100644 --- a/grudge/execution.py +++ b/grudge/execution.py @@ -182,10 +182,10 @@ class ExecutionMapper(mappers.Evaluator, def map_if(self, expr): bool_crit = self.rec(expr.condition) - if isinstance(bool_crit, DOFArray): + if isinstance(bool_crit, DOFArray): # continues below pass - elif isinstance(bool_crit, np.number): + elif isinstance(bool_crit, (np.bool_, np.bool, np.number)): if bool_crit: return self.rec(expr.then) else: @@ -194,7 +194,7 @@ class ExecutionMapper(mappers.Evaluator, raise TypeError( "Expected criterion to be of type np.number or DOFArray") - assert isinstance(bool_crit, DOFArray) + assert isinstance(bool_crit, DOFArray) ngroups = len(bool_crit) from pymbolic import var @@ -208,7 +208,7 @@ class ExecutionMapper(mappers.Evaluator, import pymbolic.primitives as p var = p.Variable - if isinstance(then, DOFArray): + if isinstance(then, DOFArray): sym_then = var("a")[subscript] def get_then(igrp): @@ -222,12 +222,12 @@ class ExecutionMapper(mappers.Evaluator, raise TypeError( "Expected 'then' to be of type np.number or DOFArray") - if isinstance(else_, DOFArray): + if isinstance(else_, DOFArray): sym_else = var("b")[subscript] def get_else(igrp): return else_[igrp] - elif isinstance(else_, np.number): + elif isinstance(else_, np.number): sym_else = var("b") def get_else(igrp): diff --git a/test/test_grudge.py b/test/test_grudge.py index b6d82a3f..2d29c9d7 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -613,6 +613,8 @@ def test_external_call(ctx_factory): @pytest.mark.parametrize("array_type", ["scalar", "vector"]) def test_function_symbol_array(ctx_factory, array_type): + """Test if `FunctionSymbol` distributed properly over object arrays.""" + ctx = ctx_factory() queue = cl.CommandQueue(ctx) actx = PyOpenCLArrayContext(queue) @@ -645,6 +647,26 @@ def test_function_symbol_array(ctx_factory, array_type): assert isinstance(norm, float) +def test_map_if(ctx_factory): + """Test :meth:`grudge.symbolic.execution.ExecutionMapper.map_if` handling + of scalar conditions. + """ + + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + actx = PyOpenCLArrayContext(queue) + + from meshmode.mesh.generation import generate_regular_rect_mesh + dim = 2 + mesh = generate_regular_rect_mesh( + a=(-0.5,)*dim, b=(0.5,)*dim, + n=(8,)*dim, order=4) + discr = DGDiscretizationWithBoundaries(actx, mesh, order=4) + + sym_if = sym.If(sym.Comparison(2.0, "<", 1.0e-14), 1.0, 2.0) + bind(discr, sym_if)(actx) + + # You can test individual routines by typing # $ python test_grudge.py 'test_routine()' -- GitLab