diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index 9f23bcb37802575d3e4ddbca4c6fc4b6c7806cc5..13dfadeddc55cadd5ed66779e3015ca98464b25d 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -203,7 +203,7 @@ class MaxReductionOperation(ScalarReductionOperation): def __call__(self, dtype, operand1, operand2, callables_table, target): dtype, = dtype - from loopy.translation_unit import update_table + from loopy.translation_unit import add_callable_to_table # getting the callable 'max' from target max_scalar_callable = target.get_device_ast_builder().known_callables["max"] @@ -213,7 +213,7 @@ class MaxReductionOperation(ScalarReductionOperation): {0: dtype, 1: dtype}, callables_table) # populate callables_table - func_id, callables_table = update_table(callables_table, "max", + func_id, callables_table = add_callable_to_table(callables_table, "max", max_scalar_callable) return ResolvedFunction(func_id)(operand1, operand2), callables_table @@ -225,7 +225,7 @@ class MinReductionOperation(ScalarReductionOperation): def __call__(self, dtype, operand1, operand2, callables_table, target): dtype, = dtype - from loopy.translation_unit import update_table + from loopy.translation_unit import add_callable_to_table # getting the callable 'min' from target min_scalar_callable = target.get_device_ast_builder().known_callables["min"] @@ -235,7 +235,7 @@ class MinReductionOperation(ScalarReductionOperation): {0: dtype, 1: dtype}, callables_table) # populate callables_table - func_id, callables_table = update_table(callables_table, "min", + func_id, callables_table = add_callable_to_table(callables_table, "min", min_scalar_callable) return ResolvedFunction(func_id)(operand1, operand2), callables_table @@ -300,7 +300,7 @@ class _SegmentedScalarReductionOperation(ReductionOperation): def neutral_element(self, scalar_dtype, segment_flag_dtype, callables_table, target): from loopy.library.function import MakeTupleCallable - from loopy.translation_unit import update_table + from loopy.translation_unit import add_callable_to_table scalar_neutral_element, calables_table = ( self.inner_reduction.neutral_element( @@ -313,7 +313,7 @@ class _SegmentedScalarReductionOperation(ReductionOperation): dict(enumerate([scalar_dtype, segment_flag_dtype])), callables_table) - func_id, callables_table = update_table( + func_id, callables_table = add_callable_to_table( callables_table, "make_tuple", make_tuple_callable) return ResolvedFunction(func_id)(scalar_neutral_element, @@ -344,8 +344,8 @@ class _SegmentedScalarReductionOperation(ReductionOperation): callables_table)) # populate callables_table - from loopy.translation_unit import update_table - func_id, callables_table = update_table( + from loopy.translation_unit import add_callable_to_table + func_id, callables_table = add_callable_to_table( callables_table, SegmentedOp(self), segmented_scalar_callable) return (ResolvedFunction(func_id)(*(operand1 + operand2)), @@ -410,7 +410,7 @@ class _ArgExtremumReductionOperation(ReductionOperation): scalar_neutral_element = scalar_neutral_func(scalar_dtype) from loopy.library.function import MakeTupleCallable - from loopy.translation_unit import update_table + from loopy.translation_unit import add_callable_to_table make_tuple_callable = MakeTupleCallable( name="make_tuple") @@ -419,8 +419,9 @@ class _ArgExtremumReductionOperation(ReductionOperation): callables_table) # populate callables_table - func_id, callables_table = update_table(callables_table, "make_tuple", - make_tuple_callable) + func_id, callables_table = add_callable_to_table(callables_table, + "make_tuple", + make_tuple_callable) return ResolvedFunction(func_id)(scalar_neutral_element, index_dtype.numpy_dtype.type(-1)), callables_table @@ -448,8 +449,8 @@ class _ArgExtremumReductionOperation(ReductionOperation): callables_table)) # populate callables_table - from loopy.translation_unit import update_table - func_id, callables_table = update_table( + from loopy.translation_unit import add_callable_to_table + func_id, callables_table = add_callable_to_table( callables_table, ArgExtOp(self), arg_ext_scalar_callable) return (ResolvedFunction(func_id)(*(operand1 + operand2)), diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index 250bf3294fe220c348dee1e85797aa31a3b6b700..127e6341a499d21a52046f1628aeabbc8802da38 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -749,7 +749,7 @@ def for_each_kernel(transform): return wraps(transform)(_collective_transform) -def update_table(callables_table, clbl_id, clbl): +def add_callable_to_table(callables_table, clbl_id, clbl): """ Returns a tuple ``new_clbl_id, new_callables_table`` where *new_callables_table* is a copy of *callables_table* with *clbl* in its