From 72bd90c18a8579d59d3335ad85e41747385bcd42 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 4 Sep 2013 20:40:58 -0500 Subject: [PATCH] Make add_dtypes also operate on temporary variables, rename to match --- doc/reference.rst | 4 ++-- loopy/__init__.py | 8 ++++---- loopy/auto_test.py | 2 +- loopy/compiled.py | 25 ++++++++++++++++++------- loopy/kernel/tools.py | 29 ++++++++++++++++++++++------- test/test_loopy.py | 2 +- 6 files changed, 48 insertions(+), 22 deletions(-) diff --git a/doc/reference.rst b/doc/reference.rst index bf11a96cc..ac443cf62 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -322,11 +322,11 @@ Library interface Argument types ^^^^^^^^^^^^^^ -.. autofunction:: add_argument_dtypes +.. autofunction:: add_dtypes .. autofunction:: infer_unknown_types -.. autofunction:: add_and_infer_argument_dtypes +.. autofunction:: add_and_infer_dtypes Finishing up ^^^^^^^^^^^^ diff --git a/loopy/__init__.py b/loopy/__init__.py index 439388f4d..a6c44e8ca 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -57,8 +57,8 @@ from loopy.kernel.data import ( from loopy.kernel import LoopKernel from loopy.kernel.tools import ( get_dot_dependency_graph, - add_argument_dtypes, - add_and_infer_argument_dtypes) + add_dtypes, + add_and_infer_dtypes) from loopy.kernel.creation import make_kernel, UniqueName from loopy.library.reduction import register_reduction_parser from loopy.subst import extract_subst, expand_subst @@ -95,8 +95,8 @@ __all__ = [ "precompute", "split_arg_axis", "find_padding_multiple", "add_padding", - "get_dot_dependency_graph", "add_argument_dtypes", - "infer_argument_dtypes", "add_and_infer_argument_dtypes", + "get_dot_dependency_graph", "add_dtypes", + "infer_argument_dtypes", "add_and_infer_dtypes", "preprocess_kernel", "realize_reduction", "infer_unknown_types", "generate_loop_schedules", diff --git a/loopy/auto_test.py b/loopy/auto_test.py index 6a1abce98..03f7ac0a3 100644 --- a/loopy/auto_test.py +++ b/loopy/auto_test.py @@ -108,7 +108,7 @@ def make_ref_args(kernel, impl_arg_info, queue, parameters, fill_value): if dtype is None: raise RuntimeError("dtype for argument '%s' is not yet " "known. Perhaps you want to use " - "loopy.add_argument_dtypes " + "loopy.add_dtypes " "or loopy.infer_argument_dtypes?" % arg.name) diff --git a/loopy/compiled.py b/loopy/compiled.py index bda857b22..3396f482c 100644 --- a/loopy/compiled.py +++ b/loopy/compiled.py @@ -30,6 +30,7 @@ from pytools import Record, memoize_method from loopy.diagnostic import ParameterFinderWarning from pytools.py_codegen import ( Indentation, PythonFunctionGenerator) +from loopy.diagnostic import LoopyError # {{{ object array argument packing @@ -665,17 +666,27 @@ class CompiledKernel: if arg.name in self.kernel.get_written_variables()) @memoize_method - def get_kernel(self, arg_to_dtype_set): + def get_kernel(self, var_to_dtype_set): kernel = self.kernel - from loopy.kernel.tools import add_argument_dtypes + from loopy.kernel.tools import add_dtypes - if arg_to_dtype_set: - arg_to_dtype = {} - for arg, dtype in arg_to_dtype_set: - arg_to_dtype[kernel.impl_arg_to_arg[arg].name] = dtype + if var_to_dtype_set: + var_to_dtype = {} + for var, dtype in var_to_dtype_set: + try: + dest_name = kernel.impl_arg_to_arg[var].name + except KeyError: + dest_name = var + + try: + var_to_dtype[dest_name] = dtype + except KeyError: + raise LoopyError("cannot set type for '%s': " + "no known variable/argument with that name" + % var) - kernel = add_argument_dtypes(kernel, arg_to_dtype) + kernel = add_dtypes(kernel, var_to_dtype) from loopy.preprocess import infer_unknown_types kernel = infer_unknown_types(kernel, expect_completion=True) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 507cc0e34..bf4db6d1b 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -34,10 +34,10 @@ logger = logging.getLogger(__name__) # {{{ add and infer argument dtypes -def add_argument_dtypes(knl, dtype_dict): - """Specify remaining unspecified argument types. +def add_dtypes(knl, dtype_dict): + """Specify remaining unspecified argument/temporary variable types. - :arg dtype_dict: a mapping from argument names to :class:`numpy.dtype` + :arg dtype_dict: a mapping from variable names to :class:`numpy.dtype` instances """ dtype_dict = dtype_dict.copy() @@ -56,13 +56,28 @@ def add_argument_dtypes(knl, dtype_dict): new_args.append(arg) - knl = knl.copy(args=new_args) + new_temp_vars = knl.temporary_variables.copy() + + import loopy as lp + for tv_name in knl.temporary_variables: + new_dtype = dtype_dict.pop(tv_name, None) + if new_dtype is not None: + new_dtype = np.dtype(new_dtype) + tv = new_temp_vars[tv_name] + if (tv.dtype is not None and tv.dtype is not lp.auto) \ + and tv.dtype != new_dtype: + raise RuntimeError( + "temporary variable '%s' already has a different dtype " + "(existing: %s, new: %s)" + % (tv_name, tv.dtype, new_dtype)) + + new_temp_vars[tv_name] = tv.copy(dtype=new_dtype) if dtype_dict: raise RuntimeError("unused argument dtypes: %s" % ", ".join(dtype_dict)) - return knl.copy(args=new_args) + return knl.copy(args=new_args, temporary_variables=new_temp_vars) def get_arguments_with_incomplete_dtype(knl): @@ -70,8 +85,8 @@ def get_arguments_with_incomplete_dtype(knl): if arg.dtype is None] -def add_and_infer_argument_dtypes(knl, dtype_dict): - knl = add_argument_dtypes(knl, dtype_dict) +def add_and_infer_dtypes(knl, dtype_dict): + knl = add_dtypes(knl, dtype_dict) from loopy.preprocess import infer_unknown_types return infer_unknown_types(knl, expect_completion=True) diff --git a/test/test_loopy.py b/test/test_loopy.py index bd5dbbde1..8092e9c6f 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -337,7 +337,7 @@ def test_stencil_with_overfetch(ctx_factory): ], assumptions="n>=1") - knl = lp.add_and_infer_argument_dtypes(knl, dict(a=np.float32)) + knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32)) ref_knl = knl -- GitLab