Skip to content
Snippets Groups Projects
Commit bd3222c7 authored by James Stevens's avatar James Stevens
Browse files

added version of add_and_infer_dtypes that accepts type dicts with unused variables

parent 7364c0e6
No related branches found
No related tags found
No related merge requests found
......@@ -43,6 +43,22 @@ def add_dtypes(knl, dtype_dict):
:arg dtype_dict: a mapping from variable names to :class:`numpy.dtype`
instances
"""
dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes(knl, dtype_dict)
if dtype_dict_remainder:
raise RuntimeError("unused argument dtypes: %s"
% ", ".join(dtype_dict_remainder))
return knl.copy(args=new_args, temporary_variables=new_temp_vars)
def _add_dtypes_overdetermined(knl, dtype_dict):
dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes(knl, dtype_dict)
# do not throw error for unused args
return knl.copy(args=new_args, temporary_variables=new_temp_vars)
def _add_dtypes(knl, dtype_dict):
dtype_dict = dtype_dict.copy()
new_args = []
......@@ -76,11 +92,7 @@ def add_dtypes(knl, dtype_dict):
new_temp_vars[tv_name] = tv.copy(dtype=new_dtype)
if dtype_dict:
raise RuntimeError("unused argument dtypes: %s"
% ", ".join(dtype_dict))
return knl.copy(args=new_args, temporary_variables=new_temp_vars)
return dtype_dict, new_args, new_temp_vars
def get_arguments_with_incomplete_dtype(knl):
......@@ -94,6 +106,12 @@ def add_and_infer_dtypes(knl, dtype_dict):
from loopy.preprocess import infer_unknown_types
return infer_unknown_types(knl, expect_completion=True)
def _add_and_infer_dtypes_overdetermined(knl, dtype_dict):
knl = _add_dtypes_overdetermined(knl, dtype_dict)
from loopy.preprocess import infer_unknown_types
return infer_unknown_types(knl, expect_completion=True)
# }}}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment