From d1617b1106c02adf0d45c35bfa5db5e5e495b740 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Wed, 26 May 2021 13:14:43 -0500
Subject: [PATCH] generate_preambles: replace isinstance 'if' chain with sub-classes

---
 loopy/library/reduction.py | 107 ++++++++++++++++++++-----------------
 1 file changed, 57 insertions(+), 50 deletions(-)

diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py
index 13dfadedd..67043e1af 100644
--- a/loopy/library/reduction.py
+++ b/loopy/library/reduction.py
@@ -334,8 +334,7 @@ class _SegmentedScalarReductionOperation(ReductionOperation):
                 other.inner_reduction)
 
     def __call__(self, dtypes, operand1, operand2, callables_table, target):
-        segmented_scalar_callable = ReductionCallable(
-                SegmentedOp(self))
+        segmented_scalar_callable = SegmentOpCallable(SegmentedOp(self))
 
         # type specialize the callable
         segmented_scalar_callable, callables_table = (
@@ -440,7 +439,7 @@ class _ArgExtremumReductionOperation(ReductionOperation):
         return 2
 
     def __call__(self, dtypes, operand1, operand2, callables_table, target):
-        arg_ext_scalar_callable = ReductionCallable(ArgExtOp(self))
+        arg_ext_scalar_callable = ArgExtOpCallable(ArgExtOp(self))
 
         # type specialize the callable
         arg_ext_scalar_callable, callables_table = (
@@ -561,58 +560,66 @@ class ReductionCallable(ScalarCallable):
                 self.copy(arg_id_to_descr=arg_id_to_descr),
                 callables_table)
 
+
+class ArgExtOpCallable(ReductionCallable):
+
     def generate_preambles(self, target):
-        if isinstance(self.name, ArgExtOp):
-            op = self.name.reduction_op
-            scalar_dtype = self.arg_id_to_dtype[-1]
-            index_dtype = self.arg_id_to_dtype[-2]
-
-            prefix = op.prefix(scalar_dtype, index_dtype)
-
-            yield (prefix, """
-            inline {scalar_t} {prefix}_op(
-                {scalar_t} op1, {index_t} index1,
-                {scalar_t} op2, {index_t} index2,
-                {index_t} *index_out)
+        op = self.name.reduction_op
+        scalar_dtype = self.arg_id_to_dtype[-1]
+        index_dtype = self.arg_id_to_dtype[-2]
+
+        prefix = op.prefix(scalar_dtype, index_dtype)
+
+        yield (prefix, """
+        inline {scalar_t} {prefix}_op(
+            {scalar_t} op1, {index_t} index1,
+            {scalar_t} op2, {index_t} index2,
+            {index_t} *index_out)
+        {{
+            if (op2 {comp} op1)
             {{
-                if (op2 {comp} op1)
-                {{
-                    *index_out = index2;
-                    return op2;
-                }}
-                else
-                {{
-                    *index_out = index1;
-                    return op1;
-                }}
+                *index_out = index2;
+                return op2;
             }}
-            """.format(
-                   scalar_t=target.dtype_to_typename(scalar_dtype),
-                   prefix=prefix,
-                   index_t=target.dtype_to_typename(index_dtype),
-                   comp=op.update_comparison,
-                   ))
-        elif isinstance(self.name, SegmentedOp):
-            op = self.name.reduction_op
-            scalar_dtype = self.arg_id_to_dtype[-1]
-            segment_flag_dtype = self.arg_id_to_dtype[-2]
-            prefix = op.prefix(scalar_dtype, segment_flag_dtype)
-
-            yield (prefix, """
-            inline {scalar_t} {prefix}_op(
-                {scalar_t} op1, {segment_flag_t} segment_flag1,
-                {scalar_t} op2, {segment_flag_t} segment_flag2,
-                {segment_flag_t} *segment_flag_out)
+            else
             {{
-                *segment_flag_out = segment_flag1 | segment_flag2;
-                return segment_flag2 ? op2 : {combined};
+                *index_out = index1;
+                return op1;
             }}
-            """.format(
-                   scalar_t=target.dtype_to_typename(scalar_dtype),
-                   prefix=prefix,
-                   segment_flag_t=target.dtype_to_typename(segment_flag_dtype),
-                   combined=op.op % ("op1", "op2"),
-                   ))
+        }}
+        """.format(
+                scalar_t=target.dtype_to_typename(scalar_dtype),
+                prefix=prefix,
+                index_t=target.dtype_to_typename(index_dtype),
+                comp=op.update_comparison,
+                ))
+
+        return
+
+
+class SegmentOpCallable(ReductionCallable):
+
+    def generate_preambles(self, target):
+        op = self.name.reduction_op
+        scalar_dtype = self.arg_id_to_dtype[-1]
+        segment_flag_dtype = self.arg_id_to_dtype[-2]
+        prefix = op.prefix(scalar_dtype, segment_flag_dtype)
+
+        yield (prefix, """
+        inline {scalar_t} {prefix}_op(
+            {scalar_t} op1, {segment_flag_t} segment_flag1,
+            {scalar_t} op2, {segment_flag_t} segment_flag2,
+            {segment_flag_t} *segment_flag_out)
+        {{
+            *segment_flag_out = segment_flag1 | segment_flag2;
+            return segment_flag2 ? op2 : {combined};
+        }}
+        """.format(
+                scalar_t=target.dtype_to_typename(scalar_dtype),
+                prefix=prefix,
+                segment_flag_t=target.dtype_to_typename(segment_flag_dtype),
+                combined=op.op % ("op1", "op2"),
+                ))
 
         return
 
-- 
GitLab