diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 769301ed99d8486a396dbf94b06bf361771596cb..1f7feb1fa23d487f2d0a2be2338b92409e64bfe7 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -43,6 +43,22 @@ def add_dtypes(knl, dtype_dict): :arg dtype_dict: a mapping from variable names to :class:`numpy.dtype` instances """ + dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes(knl, dtype_dict) + + if dtype_dict_remainder: + raise RuntimeError("unused argument dtypes: %s" + % ", ".join(dtype_dict_remainder)) + + return knl.copy(args=new_args, temporary_variables=new_temp_vars) + + +def _add_dtypes_overdetermined(knl, dtype_dict): + dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes(knl, dtype_dict) + # do not throw error for unused args + return knl.copy(args=new_args, temporary_variables=new_temp_vars) + + +def _add_dtypes(knl, dtype_dict): dtype_dict = dtype_dict.copy() new_args = [] @@ -76,11 +92,7 @@ def add_dtypes(knl, dtype_dict): new_temp_vars[tv_name] = tv.copy(dtype=new_dtype) - if dtype_dict: - raise RuntimeError("unused argument dtypes: %s" - % ", ".join(dtype_dict)) - - return knl.copy(args=new_args, temporary_variables=new_temp_vars) + return dtype_dict, new_args, new_temp_vars def get_arguments_with_incomplete_dtype(knl): @@ -94,6 +106,12 @@ def add_and_infer_dtypes(knl, dtype_dict): from loopy.preprocess import infer_unknown_types return infer_unknown_types(knl, expect_completion=True) +def _add_and_infer_dtypes_overdetermined(knl, dtype_dict): + knl = _add_dtypes_overdetermined(knl, dtype_dict) + + from loopy.preprocess import infer_unknown_types + return infer_unknown_types(knl, expect_completion=True) + # }}}