From bd3222c7de8a0b5169177017821aaa122ad16107 Mon Sep 17 00:00:00 2001 From: James Stevens <jdsteve2@porter.cs.illinois.edu> Date: Sat, 21 Nov 2015 17:15:00 -0600 Subject: [PATCH] added version of add_and_infer_dtypes that accepts type dicts with unused variables --- loopy/kernel/tools.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 769301ed9..1f7feb1fa 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -43,6 +43,22 @@ def add_dtypes(knl, dtype_dict): :arg dtype_dict: a mapping from variable names to :class:`numpy.dtype` instances """ + dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes(knl, dtype_dict) + + if dtype_dict_remainder: + raise RuntimeError("unused argument dtypes: %s" + % ", ".join(dtype_dict_remainder)) + + return knl.copy(args=new_args, temporary_variables=new_temp_vars) + + +def _add_dtypes_overdetermined(knl, dtype_dict): + dtype_dict_remainder, new_args, new_temp_vars = _add_dtypes(knl, dtype_dict) + # do not throw error for unused args + return knl.copy(args=new_args, temporary_variables=new_temp_vars) + + +def _add_dtypes(knl, dtype_dict): dtype_dict = dtype_dict.copy() new_args = [] @@ -76,11 +92,7 @@ def add_dtypes(knl, dtype_dict): new_temp_vars[tv_name] = tv.copy(dtype=new_dtype) - if dtype_dict: - raise RuntimeError("unused argument dtypes: %s" - % ", ".join(dtype_dict)) - - return knl.copy(args=new_args, temporary_variables=new_temp_vars) + return dtype_dict, new_args, new_temp_vars def get_arguments_with_incomplete_dtype(knl): @@ -94,6 +106,12 @@ def add_and_infer_dtypes(knl, dtype_dict): from loopy.preprocess import infer_unknown_types return infer_unknown_types(knl, expect_completion=True) +def _add_and_infer_dtypes_overdetermined(knl, dtype_dict): + knl = _add_dtypes_overdetermined(knl, dtype_dict) + + from loopy.preprocess import infer_unknown_types + return infer_unknown_types(knl, expect_completion=True) + # }}} -- GitLab