From f75c36af4e72035fe01b652e1453413d6f3a787d Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sun, 16 Aug 2020 20:17:02 -0500
Subject: [PATCH] fix map_if handling of scalar conditions

---
 grudge/execution.py | 12 ++++++------
 test/test_grudge.py | 22 ++++++++++++++++++++++
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/grudge/execution.py b/grudge/execution.py
index 88c0d318..0f6711e8 100644
--- a/grudge/execution.py
+++ b/grudge/execution.py
@@ -182,10 +182,10 @@ class ExecutionMapper(mappers.Evaluator,
     def map_if(self, expr):
         bool_crit = self.rec(expr.condition)
 
-        if isinstance(bool_crit,  DOFArray):
+        if isinstance(bool_crit, DOFArray):
             # continues below
             pass
-        elif isinstance(bool_crit,  np.number):
+        elif isinstance(bool_crit, (np.bool_, np.bool, np.number)):
             if bool_crit:
                 return self.rec(expr.then)
             else:
@@ -194,7 +194,7 @@ class ExecutionMapper(mappers.Evaluator,
             raise TypeError(
                 "Expected criterion to be of type np.number or DOFArray")
 
-        assert isinstance(bool_crit,  DOFArray)
+        assert isinstance(bool_crit, DOFArray)
         ngroups = len(bool_crit)
 
         from pymbolic import var
@@ -208,7 +208,7 @@ class ExecutionMapper(mappers.Evaluator,
         import pymbolic.primitives as p
         var = p.Variable
 
-        if isinstance(then,  DOFArray):
+        if isinstance(then, DOFArray):
             sym_then = var("a")[subscript]
 
             def get_then(igrp):
@@ -222,12 +222,12 @@ class ExecutionMapper(mappers.Evaluator,
             raise TypeError(
                 "Expected 'then' to be of type np.number or DOFArray")
 
-        if isinstance(else_,  DOFArray):
+        if isinstance(else_, DOFArray):
             sym_else = var("b")[subscript]
 
             def get_else(igrp):
                 return else_[igrp]
-        elif isinstance(else_,  np.number):
+        elif isinstance(else_, np.number):
             sym_else = var("b")
 
             def get_else(igrp):
diff --git a/test/test_grudge.py b/test/test_grudge.py
index b6d82a3f..2d29c9d7 100644
--- a/test/test_grudge.py
+++ b/test/test_grudge.py
@@ -613,6 +613,8 @@ def test_external_call(ctx_factory):
 
 @pytest.mark.parametrize("array_type", ["scalar", "vector"])
 def test_function_symbol_array(ctx_factory, array_type):
+    """Test if `FunctionSymbol` distributed properly over object arrays."""
+
     ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)
     actx = PyOpenCLArrayContext(queue)
@@ -645,6 +647,26 @@ def test_function_symbol_array(ctx_factory, array_type):
     assert isinstance(norm, float)
 
 
+def test_map_if(ctx_factory):
+    """Test :meth:`grudge.symbolic.execution.ExecutionMapper.map_if` handling
+    of scalar conditions.
+    """
+
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+    actx = PyOpenCLArrayContext(queue)
+
+    from meshmode.mesh.generation import generate_regular_rect_mesh
+    dim = 2
+    mesh = generate_regular_rect_mesh(
+            a=(-0.5,)*dim, b=(0.5,)*dim,
+            n=(8,)*dim, order=4)
+    discr = DGDiscretizationWithBoundaries(actx, mesh, order=4)
+
+    sym_if = sym.If(sym.Comparison(2.0, "<", 1.0e-14), 1.0, 2.0)
+    bind(discr, sym_if)(actx)
+
+
 # You can test individual routines by typing
 # $ python test_grudge.py 'test_routine()'
 
-- 
GitLab