Skip to content
Snippets Groups Projects
Commit d1617b11 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni Committed by Andreas Klöckner
Browse files

avoid 'if' by appropriately using sub-classes

parent 06d9db43
No related branches found
No related tags found
1 merge request!426Discussion: kernel_callables_v3-edit2
...@@ -334,8 +334,7 @@ class _SegmentedScalarReductionOperation(ReductionOperation): ...@@ -334,8 +334,7 @@ class _SegmentedScalarReductionOperation(ReductionOperation):
other.inner_reduction) other.inner_reduction)
def __call__(self, dtypes, operand1, operand2, callables_table, target): def __call__(self, dtypes, operand1, operand2, callables_table, target):
segmented_scalar_callable = ReductionCallable( segmented_scalar_callable = SegmentOpCallable(SegmentedOp(self))
SegmentedOp(self))
# type specialize the callable # type specialize the callable
segmented_scalar_callable, callables_table = ( segmented_scalar_callable, callables_table = (
...@@ -440,7 +439,7 @@ class _ArgExtremumReductionOperation(ReductionOperation): ...@@ -440,7 +439,7 @@ class _ArgExtremumReductionOperation(ReductionOperation):
return 2 return 2
def __call__(self, dtypes, operand1, operand2, callables_table, target): def __call__(self, dtypes, operand1, operand2, callables_table, target):
arg_ext_scalar_callable = ReductionCallable(ArgExtOp(self)) arg_ext_scalar_callable = ArgExtOpCallable(ArgExtOp(self))
# type specialize the callable # type specialize the callable
arg_ext_scalar_callable, callables_table = ( arg_ext_scalar_callable, callables_table = (
...@@ -561,58 +560,66 @@ class ReductionCallable(ScalarCallable): ...@@ -561,58 +560,66 @@ class ReductionCallable(ScalarCallable):
self.copy(arg_id_to_descr=arg_id_to_descr), self.copy(arg_id_to_descr=arg_id_to_descr),
callables_table) callables_table)
class ArgExtOpCallable(ReductionCallable):
def generate_preambles(self, target): def generate_preambles(self, target):
if isinstance(self.name, ArgExtOp): op = self.name.reduction_op
op = self.name.reduction_op scalar_dtype = self.arg_id_to_dtype[-1]
scalar_dtype = self.arg_id_to_dtype[-1] index_dtype = self.arg_id_to_dtype[-2]
index_dtype = self.arg_id_to_dtype[-2]
prefix = op.prefix(scalar_dtype, index_dtype)
prefix = op.prefix(scalar_dtype, index_dtype)
yield (prefix, """
yield (prefix, """ inline {scalar_t} {prefix}_op(
inline {scalar_t} {prefix}_op( {scalar_t} op1, {index_t} index1,
{scalar_t} op1, {index_t} index1, {scalar_t} op2, {index_t} index2,
{scalar_t} op2, {index_t} index2, {index_t} *index_out)
{index_t} *index_out) {{
if (op2 {comp} op1)
{{ {{
if (op2 {comp} op1) *index_out = index2;
{{ return op2;
*index_out = index2;
return op2;
}}
else
{{
*index_out = index1;
return op1;
}}
}} }}
""".format( else
scalar_t=target.dtype_to_typename(scalar_dtype),
prefix=prefix,
index_t=target.dtype_to_typename(index_dtype),
comp=op.update_comparison,
))
elif isinstance(self.name, SegmentedOp):
op = self.name.reduction_op
scalar_dtype = self.arg_id_to_dtype[-1]
segment_flag_dtype = self.arg_id_to_dtype[-2]
prefix = op.prefix(scalar_dtype, segment_flag_dtype)
yield (prefix, """
inline {scalar_t} {prefix}_op(
{scalar_t} op1, {segment_flag_t} segment_flag1,
{scalar_t} op2, {segment_flag_t} segment_flag2,
{segment_flag_t} *segment_flag_out)
{{ {{
*segment_flag_out = segment_flag1 | segment_flag2; *index_out = index1;
return segment_flag2 ? op2 : {combined}; return op1;
}} }}
""".format( }}
scalar_t=target.dtype_to_typename(scalar_dtype), """.format(
prefix=prefix, scalar_t=target.dtype_to_typename(scalar_dtype),
segment_flag_t=target.dtype_to_typename(segment_flag_dtype), prefix=prefix,
combined=op.op % ("op1", "op2"), index_t=target.dtype_to_typename(index_dtype),
)) comp=op.update_comparison,
))
return
class SegmentOpCallable(ReductionCallable):
def generate_preambles(self, target):
op = self.name.reduction_op
scalar_dtype = self.arg_id_to_dtype[-1]
segment_flag_dtype = self.arg_id_to_dtype[-2]
prefix = op.prefix(scalar_dtype, segment_flag_dtype)
yield (prefix, """
inline {scalar_t} {prefix}_op(
{scalar_t} op1, {segment_flag_t} segment_flag1,
{scalar_t} op2, {segment_flag_t} segment_flag2,
{segment_flag_t} *segment_flag_out)
{{
*segment_flag_out = segment_flag1 | segment_flag2;
return segment_flag2 ? op2 : {combined};
}}
""".format(
scalar_t=target.dtype_to_typename(scalar_dtype),
prefix=prefix,
segment_flag_t=target.dtype_to_typename(segment_flag_dtype),
combined=op.op % ("op1", "op2"),
))
return return
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment