Skip to content
Snippets Groups Projects
Commit f75c36af authored by Alexandru Fikl's avatar Alexandru Fikl
Browse files

fix map_if handling of scalar conditions

parent 2b7122d0
No related branches found
No related tags found
No related merge requests found
......@@ -182,10 +182,10 @@ class ExecutionMapper(mappers.Evaluator,
def map_if(self, expr):
bool_crit = self.rec(expr.condition)
if isinstance(bool_crit, DOFArray):
if isinstance(bool_crit, DOFArray):
# continues below
pass
elif isinstance(bool_crit, np.number):
elif isinstance(bool_crit, (np.bool_, np.bool, np.number)):
if bool_crit:
return self.rec(expr.then)
else:
......@@ -194,7 +194,7 @@ class ExecutionMapper(mappers.Evaluator,
raise TypeError(
"Expected criterion to be of type np.number or DOFArray")
assert isinstance(bool_crit, DOFArray)
assert isinstance(bool_crit, DOFArray)
ngroups = len(bool_crit)
from pymbolic import var
......@@ -208,7 +208,7 @@ class ExecutionMapper(mappers.Evaluator,
import pymbolic.primitives as p
var = p.Variable
if isinstance(then, DOFArray):
if isinstance(then, DOFArray):
sym_then = var("a")[subscript]
def get_then(igrp):
......@@ -222,12 +222,12 @@ class ExecutionMapper(mappers.Evaluator,
raise TypeError(
"Expected 'then' to be of type np.number or DOFArray")
if isinstance(else_, DOFArray):
if isinstance(else_, DOFArray):
sym_else = var("b")[subscript]
def get_else(igrp):
return else_[igrp]
elif isinstance(else_, np.number):
elif isinstance(else_, np.number):
sym_else = var("b")
def get_else(igrp):
......
......@@ -613,6 +613,8 @@ def test_external_call(ctx_factory):
@pytest.mark.parametrize("array_type", ["scalar", "vector"])
def test_function_symbol_array(ctx_factory, array_type):
"""Test if `FunctionSymbol` distributed properly over object arrays."""
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
actx = PyOpenCLArrayContext(queue)
......@@ -645,6 +647,26 @@ def test_function_symbol_array(ctx_factory, array_type):
assert isinstance(norm, float)
def test_map_if(ctx_factory):
"""Test :meth:`grudge.symbolic.execution.ExecutionMapper.map_if` handling
of scalar conditions.
"""
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
actx = PyOpenCLArrayContext(queue)
from meshmode.mesh.generation import generate_regular_rect_mesh
dim = 2
mesh = generate_regular_rect_mesh(
a=(-0.5,)*dim, b=(0.5,)*dim,
n=(8,)*dim, order=4)
discr = DGDiscretizationWithBoundaries(actx, mesh, order=4)
sym_if = sym.If(sym.Comparison(2.0, "<", 1.0e-14), 1.0, 2.0)
bind(discr, sym_if)(actx)
# You can test individual routines by typing
# $ python test_grudge.py 'test_routine()'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment