From d1617b1106c02adf0d45c35bfa5db5e5e495b740 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Wed, 26 May 2021 13:14:43 -0500 Subject: [PATCH] avoid 'if' by appropriately using sub-classes --- loopy/library/reduction.py | 107 ++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 50 deletions(-) diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index 13dfadedd..67043e1af 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -334,8 +334,7 @@ class _SegmentedScalarReductionOperation(ReductionOperation): other.inner_reduction) def __call__(self, dtypes, operand1, operand2, callables_table, target): - segmented_scalar_callable = ReductionCallable( - SegmentedOp(self)) + segmented_scalar_callable = SegmentOpCallable(SegmentedOp(self)) # type specialize the callable segmented_scalar_callable, callables_table = ( @@ -440,7 +439,7 @@ class _ArgExtremumReductionOperation(ReductionOperation): return 2 def __call__(self, dtypes, operand1, operand2, callables_table, target): - arg_ext_scalar_callable = ReductionCallable(ArgExtOp(self)) + arg_ext_scalar_callable = ArgExtOpCallable(ArgExtOp(self)) # type specialize the callable arg_ext_scalar_callable, callables_table = ( @@ -561,58 +560,66 @@ class ReductionCallable(ScalarCallable): self.copy(arg_id_to_descr=arg_id_to_descr), callables_table) + +class ArgExtOpCallable(ReductionCallable): + def generate_preambles(self, target): - if isinstance(self.name, ArgExtOp): - op = self.name.reduction_op - scalar_dtype = self.arg_id_to_dtype[-1] - index_dtype = self.arg_id_to_dtype[-2] - - prefix = op.prefix(scalar_dtype, index_dtype) - - yield (prefix, """ - inline {scalar_t} {prefix}_op( - {scalar_t} op1, {index_t} index1, - {scalar_t} op2, {index_t} index2, - {index_t} *index_out) + op = self.name.reduction_op + scalar_dtype = self.arg_id_to_dtype[-1] + index_dtype = self.arg_id_to_dtype[-2] + + prefix = op.prefix(scalar_dtype, index_dtype) + + yield (prefix, """ + inline {scalar_t} {prefix}_op( + {scalar_t} op1, {index_t} index1, + {scalar_t} op2, {index_t} index2, + {index_t} *index_out) + {{ + if (op2 {comp} op1) {{ - if (op2 {comp} op1) - {{ - *index_out = index2; - return op2; - }} - else - {{ - *index_out = index1; - return op1; - }} + *index_out = index2; + return op2; }} - """.format( - scalar_t=target.dtype_to_typename(scalar_dtype), - prefix=prefix, - index_t=target.dtype_to_typename(index_dtype), - comp=op.update_comparison, - )) - elif isinstance(self.name, SegmentedOp): - op = self.name.reduction_op - scalar_dtype = self.arg_id_to_dtype[-1] - segment_flag_dtype = self.arg_id_to_dtype[-2] - prefix = op.prefix(scalar_dtype, segment_flag_dtype) - - yield (prefix, """ - inline {scalar_t} {prefix}_op( - {scalar_t} op1, {segment_flag_t} segment_flag1, - {scalar_t} op2, {segment_flag_t} segment_flag2, - {segment_flag_t} *segment_flag_out) + else {{ - *segment_flag_out = segment_flag1 | segment_flag2; - return segment_flag2 ? op2 : {combined}; + *index_out = index1; + return op1; }} - """.format( - scalar_t=target.dtype_to_typename(scalar_dtype), - prefix=prefix, - segment_flag_t=target.dtype_to_typename(segment_flag_dtype), - combined=op.op % ("op1", "op2"), - )) + }} + """.format( + scalar_t=target.dtype_to_typename(scalar_dtype), + prefix=prefix, + index_t=target.dtype_to_typename(index_dtype), + comp=op.update_comparison, + )) + + return + + +class SegmentOpCallable(ReductionCallable): + + def generate_preambles(self, target): + op = self.name.reduction_op + scalar_dtype = self.arg_id_to_dtype[-1] + segment_flag_dtype = self.arg_id_to_dtype[-2] + prefix = op.prefix(scalar_dtype, segment_flag_dtype) + + yield (prefix, """ + inline {scalar_t} {prefix}_op( + {scalar_t} op1, {segment_flag_t} segment_flag1, + {scalar_t} op2, {segment_flag_t} segment_flag2, + {segment_flag_t} *segment_flag_out) + {{ + *segment_flag_out = segment_flag1 | segment_flag2; + return segment_flag2 ? op2 : {combined}; + }} + """.format( + scalar_t=target.dtype_to_typename(scalar_dtype), + prefix=prefix, + segment_flag_t=target.dtype_to_typename(segment_flag_dtype), + combined=op.op % ("op1", "op2"), + )) return -- GitLab