Commit e3bb8f37 authored by Matt Wala

Add a make_tuple() function to loopy.

This function does trivial things, but it's there to solve the problem
that the reduction neutral element getters are not allowed to store
dtypes (#80).

The function mangler demands that a function knows its type based on
its arguments. For the neutral element getters, this is impossible
because they take zero arguments. The simplest fix I can think of is
to change a call to neutral_element() to a call to make_tuple().
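
The constraint described above can be illustrated with a small standalone sketch. This is not loopy's implementation: the function name `sketch_tuple_mangler` and the use of plain strings as dtypes are assumptions made for illustration only. It shows why a call like make_tuple(1, 2) can be typed from its arguments while a zero-argument call cannot.

```python
from typing import Optional, Sequence, Tuple


def sketch_tuple_mangler(
        name: str, arg_dtypes: Sequence[str]) -> Optional[Tuple[str, ...]]:
    """Hypothetical mangler: map (name, argument dtypes) to result dtypes.

    Returns None when the name is not handled, so that a caller can fall
    through to the next mangler in a list.
    """
    if name == "make_tuple":
        # One result per argument, each carrying that argument's dtype.
        return tuple(arg_dtypes)

    # A zero-argument call such as neutral_element() offers no argument
    # dtypes to derive a result type from, so this sketch cannot handle it.
    return None


# The result types follow directly from the argument types.
result = sketch_tuple_mangler("make_tuple", ["int32", "float64"])
print(result)

# With no arguments there is nothing to infer from.
print(sketch_tuple_mangler("neutral_element", []))
```

In this sketch the dtype information lives entirely in the call's arguments, which is the property the commit message relies on when it replaces neutral_element() with make_tuple().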

Currently, the tuple code doesn't work yet due to pickling issues.  I
think the root cause is somewhere in
__hackily_ensure_multi_argument_functions_are_scoped_private().
parent 282abebe
1 merge request: !124 Remove numpy dtypes from ArgExtFunction and SegmentedFunction
@@ -25,8 +25,9 @@ THE SOFTWARE.
 def default_function_mangler(kernel, name, arg_dtypes):
     from loopy.library.reduction import reduction_function_mangler
+    from loopy.library.tuple import tuple_function_mangler
 
-    manglers = [reduction_function_mangler]
+    manglers = [reduction_function_mangler, tuple_function_mangler]
     for mangler in manglers:
         result = mangler(kernel, name, arg_dtypes)
         if result is not None:
...
@@ -442,7 +442,7 @@ def _infer_var_type(kernel, var_name, type_inf_mapper, subst_expander):
                 result_i = comp_dtype_set
                 break
 
-        assert found
+        assert found, var_name
         if result_i is not None:
             result.append(result_i)
...
@@ -176,6 +176,19 @@ def test_random123(ctx_factory, tp):
     assert (0 <= out).all()
 
 
+def test_tuple():
+    knl = lp.make_kernel(
+            "{ [i]: 0 <= i < 10 }",
+            """
+            a, b = make_tuple(1, 2)
+            """)
+
+    print(
+        lp.generate_code(
+            lp.get_one_scheduled_kernel(
+                lp.preprocess_kernel(knl)))[0])
+
+
 def test_clamp(ctx_factory):
     ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)
...