diff --git a/doc/reference.rst b/doc/reference.rst index bf11a96cc2636e15eb19c5f1d6ff3c5db334b6c8..ac443cf624c0f94871e1c481c54b852fe2f6fa14 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -322,11 +322,11 @@ Library interface Argument types ^^^^^^^^^^^^^^ -.. autofunction:: add_argument_dtypes +.. autofunction:: add_dtypes .. autofunction:: infer_unknown_types -.. autofunction:: add_and_infer_argument_dtypes +.. autofunction:: add_and_infer_dtypes Finishing up ^^^^^^^^^^^^ diff --git a/loopy/__init__.py b/loopy/__init__.py index 439388f4d471eb943cafc73a5ab80ed6472db85c..a6c44e8ca7270ccbc32d4cf7dacd38d4f7d71bab 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -57,8 +57,8 @@ from loopy.kernel.data import ( from loopy.kernel import LoopKernel from loopy.kernel.tools import ( get_dot_dependency_graph, - add_argument_dtypes, - add_and_infer_argument_dtypes) + add_dtypes, + add_and_infer_dtypes) from loopy.kernel.creation import make_kernel, UniqueName from loopy.library.reduction import register_reduction_parser from loopy.subst import extract_subst, expand_subst @@ -95,8 +95,8 @@ __all__ = [ "precompute", "split_arg_axis", "find_padding_multiple", "add_padding", - "get_dot_dependency_graph", "add_argument_dtypes", - "infer_argument_dtypes", "add_and_infer_argument_dtypes", + "get_dot_dependency_graph", "add_dtypes", + "infer_argument_dtypes", "add_and_infer_dtypes", "preprocess_kernel", "realize_reduction", "infer_unknown_types", "generate_loop_schedules", diff --git a/loopy/auto_test.py b/loopy/auto_test.py index 6a1abce9805ebd6653b5f1a0d4319b3ea8008843..03f7ac0a3e6132e028622ee680c17cccec4ec12e 100644 --- a/loopy/auto_test.py +++ b/loopy/auto_test.py @@ -108,7 +108,7 @@ def make_ref_args(kernel, impl_arg_info, queue, parameters, fill_value): if dtype is None: raise RuntimeError("dtype for argument '%s' is not yet " "known. Perhaps you want to use " - "loopy.add_argument_dtypes " + "loopy.add_dtypes " "or loopy.infer_argument_dtypes?" % arg.name) diff --git a/loopy/compiled.py b/loopy/compiled.py index bda857b224eff646cc762420e1e8a33bee362f30..3396f482c517f1add545fc149154e2bb0afcf3b9 100644 --- a/loopy/compiled.py +++ b/loopy/compiled.py @@ -30,6 +30,7 @@ from pytools import Record, memoize_method from loopy.diagnostic import ParameterFinderWarning from pytools.py_codegen import ( Indentation, PythonFunctionGenerator) +from loopy.diagnostic import LoopyError # {{{ object array argument packing @@ -665,17 +666,27 @@ class CompiledKernel: if arg.name in self.kernel.get_written_variables()) @memoize_method - def get_kernel(self, arg_to_dtype_set): + def get_kernel(self, var_to_dtype_set): kernel = self.kernel - from loopy.kernel.tools import add_argument_dtypes + from loopy.kernel.tools import add_dtypes - if arg_to_dtype_set: - arg_to_dtype = {} - for arg, dtype in arg_to_dtype_set: - arg_to_dtype[kernel.impl_arg_to_arg[arg].name] = dtype + if var_to_dtype_set: + var_to_dtype = {} + for var, dtype in var_to_dtype_set: + try: + dest_name = kernel.impl_arg_to_arg[var].name + except KeyError: + dest_name = var + + try: + var_to_dtype[dest_name] = dtype + except KeyError: + raise LoopyError("cannot set type for '%s': " + "no known variable/argument with that name" + % var) - kernel = add_argument_dtypes(kernel, arg_to_dtype) + kernel = add_dtypes(kernel, var_to_dtype) from loopy.preprocess import infer_unknown_types kernel = infer_unknown_types(kernel, expect_completion=True) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 507cc0e34e8afc980387bb90dc8f6c7f3e118281..bf4db6d1bf4d3d92205fa5ea906eaeb767e2c0a7 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -34,10 +34,10 @@ logger = logging.getLogger(__name__) # {{{ add and infer argument dtypes -def add_argument_dtypes(knl, dtype_dict): - """Specify remaining unspecified argument types. +def add_dtypes(knl, dtype_dict): + """Specify remaining unspecified argument/temporary variable types. - :arg dtype_dict: a mapping from argument names to :class:`numpy.dtype` + :arg dtype_dict: a mapping from variable names to :class:`numpy.dtype` instances """ dtype_dict = dtype_dict.copy() @@ -56,13 +56,28 @@ def add_argument_dtypes(knl, dtype_dict): new_args.append(arg) - knl = knl.copy(args=new_args) + new_temp_vars = knl.temporary_variables.copy() + + import loopy as lp + for tv_name in knl.temporary_variables: + new_dtype = dtype_dict.pop(tv_name, None) + if new_dtype is not None: + new_dtype = np.dtype(new_dtype) + tv = new_temp_vars[tv_name] + if (tv.dtype is not None and tv.dtype is not lp.auto) \ + and tv.dtype != new_dtype: + raise RuntimeError( + "temporary variable '%s' already has a different dtype " + "(existing: %s, new: %s)" + % (tv_name, tv.dtype, new_dtype)) + + new_temp_vars[tv_name] = tv.copy(dtype=new_dtype) if dtype_dict: raise RuntimeError("unused argument dtypes: %s" % ", ".join(dtype_dict)) - return knl.copy(args=new_args) + return knl.copy(args=new_args, temporary_variables=new_temp_vars) def get_arguments_with_incomplete_dtype(knl): @@ -70,8 +85,8 @@ def get_arguments_with_incomplete_dtype(knl): if arg.dtype is None] -def add_and_infer_argument_dtypes(knl, dtype_dict): - knl = add_argument_dtypes(knl, dtype_dict) +def add_and_infer_dtypes(knl, dtype_dict): + knl = add_dtypes(knl, dtype_dict) from loopy.preprocess import infer_unknown_types return infer_unknown_types(knl, expect_completion=True) diff --git a/test/test_loopy.py b/test/test_loopy.py index bd5dbbde1f853e077987ee62a8b6180713220424..8092e9c6f87a0c40228a427dcbc20511b3531d10 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -337,7 +337,7 @@ def test_stencil_with_overfetch(ctx_factory): ], assumptions="n>=1") - knl = lp.add_and_infer_argument_dtypes(knl, dict(a=np.float32)) + knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32)) ref_knl = knl