From 72bd90c18a8579d59d3335ad85e41747385bcd42 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 4 Sep 2013 20:40:58 -0500
Subject: [PATCH] Make add_dtypes also operate on temporary variables, rename
 to match

---
 doc/reference.rst     |  4 ++--
 loopy/__init__.py     |  8 ++++----
 loopy/auto_test.py    |  2 +-
 loopy/compiled.py     | 25 ++++++++++++++++++-------
 loopy/kernel/tools.py | 29 ++++++++++++++++++++++-------
 test/test_loopy.py    |  2 +-
 6 files changed, 48 insertions(+), 22 deletions(-)

diff --git a/doc/reference.rst b/doc/reference.rst
index bf11a96cc..ac443cf62 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -322,11 +322,11 @@ Library interface
 Argument types
 ^^^^^^^^^^^^^^
 
-.. autofunction:: add_argument_dtypes
+.. autofunction:: add_dtypes
 
 .. autofunction:: infer_unknown_types
 
-.. autofunction:: add_and_infer_argument_dtypes
+.. autofunction:: add_and_infer_dtypes
 
 Finishing up
 ^^^^^^^^^^^^
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 439388f4d..a6c44e8ca 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -57,8 +57,8 @@ from loopy.kernel.data import (
 from loopy.kernel import LoopKernel
 from loopy.kernel.tools import (
         get_dot_dependency_graph,
-        add_argument_dtypes,
-        add_and_infer_argument_dtypes)
+        add_dtypes,
+        add_and_infer_dtypes)
 from loopy.kernel.creation import make_kernel, UniqueName
 from loopy.library.reduction import register_reduction_parser
 from loopy.subst import extract_subst, expand_subst
@@ -95,8 +95,8 @@ __all__ = [
         "precompute",
         "split_arg_axis", "find_padding_multiple", "add_padding",
 
-        "get_dot_dependency_graph", "add_argument_dtypes",
-        "infer_argument_dtypes", "add_and_infer_argument_dtypes",
+        "get_dot_dependency_graph", "add_dtypes",
+        "infer_argument_dtypes", "add_and_infer_dtypes",
 
         "preprocess_kernel", "realize_reduction", "infer_unknown_types",
         "generate_loop_schedules",
diff --git a/loopy/auto_test.py b/loopy/auto_test.py
index 6a1abce98..03f7ac0a3 100644
--- a/loopy/auto_test.py
+++ b/loopy/auto_test.py
@@ -108,7 +108,7 @@ def make_ref_args(kernel, impl_arg_info, queue, parameters, fill_value):
                 if dtype is None:
                     raise RuntimeError("dtype for argument '%s' is not yet "
                             "known. Perhaps you want to use "
-                            "loopy.add_argument_dtypes "
+                            "loopy.add_dtypes "
                             "or loopy.infer_argument_dtypes?"
                             % arg.name)
 
diff --git a/loopy/compiled.py b/loopy/compiled.py
index bda857b22..3396f482c 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -30,6 +30,7 @@ from pytools import Record, memoize_method
 from loopy.diagnostic import ParameterFinderWarning
 from pytools.py_codegen import (
         Indentation, PythonFunctionGenerator)
+from loopy.diagnostic import LoopyError
 
 
 # {{{ object array argument packing
@@ -665,17 +666,27 @@ class CompiledKernel:
                 if arg.name in self.kernel.get_written_variables())
 
     @memoize_method
-    def get_kernel(self, arg_to_dtype_set):
+    def get_kernel(self, var_to_dtype_set):
         kernel = self.kernel
 
-        from loopy.kernel.tools import add_argument_dtypes
+        from loopy.kernel.tools import add_dtypes
 
-        if arg_to_dtype_set:
-            arg_to_dtype = {}
-            for arg, dtype in arg_to_dtype_set:
-                arg_to_dtype[kernel.impl_arg_to_arg[arg].name] = dtype
+        if var_to_dtype_set:
+            var_to_dtype = {}
+            for var, dtype in var_to_dtype_set:
+                try:
+                    dest_name = kernel.impl_arg_to_arg[var].name
+                except KeyError:
+                    dest_name = var
+                # A dict store never raises KeyError: check the name directly.
+                if (dest_name not in kernel.temporary_variables and not any(
+                        arg.name == dest_name for arg in kernel.args)):
+                    raise LoopyError("cannot set type for '%s': "
+                            "no known variable/argument with that name"
+                            % var)
+                var_to_dtype[dest_name] = dtype
 
-            kernel = add_argument_dtypes(kernel, arg_to_dtype)
+            kernel = add_dtypes(kernel, var_to_dtype)
 
             from loopy.preprocess import infer_unknown_types
             kernel = infer_unknown_types(kernel, expect_completion=True)
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index 507cc0e34..bf4db6d1b 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -34,10 +34,10 @@ logger = logging.getLogger(__name__)
 
 # {{{ add and infer argument dtypes
 
-def add_argument_dtypes(knl, dtype_dict):
-    """Specify remaining unspecified argument types.
+def add_dtypes(knl, dtype_dict):
+    """Specify remaining unspecified argument/temporary variable types.
 
-    :arg dtype_dict: a mapping from argument names to :class:`numpy.dtype`
+    :arg dtype_dict: a mapping from variable names to :class:`numpy.dtype`
         instances
     """
     dtype_dict = dtype_dict.copy()
@@ -56,13 +56,28 @@ def add_argument_dtypes(knl, dtype_dict):
 
         new_args.append(arg)
 
-    knl = knl.copy(args=new_args)
+    new_temp_vars = knl.temporary_variables.copy()
+
+    import loopy as lp
+    for tv_name in knl.temporary_variables:
+        new_dtype = dtype_dict.pop(tv_name, None)
+        if new_dtype is not None:
+            new_dtype = np.dtype(new_dtype)
+            tv = new_temp_vars[tv_name]
+            if (tv.dtype is not None and tv.dtype is not lp.auto) \
+                    and tv.dtype != new_dtype:
+                raise RuntimeError(
+                        "temporary variable '%s' already has a different dtype "
+                        "(existing: %s, new: %s)"
+                        % (tv_name, tv.dtype, new_dtype))
+
+            new_temp_vars[tv_name] = tv.copy(dtype=new_dtype)
 
     if dtype_dict:
         raise RuntimeError("unused argument dtypes: %s"
                 % ", ".join(dtype_dict))
 
-    return knl.copy(args=new_args)
+    return knl.copy(args=new_args, temporary_variables=new_temp_vars)
 
 
 def get_arguments_with_incomplete_dtype(knl):
@@ -70,8 +85,8 @@ def get_arguments_with_incomplete_dtype(knl):
             if arg.dtype is None]
 
 
-def add_and_infer_argument_dtypes(knl, dtype_dict):
-    knl = add_argument_dtypes(knl, dtype_dict)
+def add_and_infer_dtypes(knl, dtype_dict):
+    knl = add_dtypes(knl, dtype_dict)
 
     from loopy.preprocess import infer_unknown_types
     return infer_unknown_types(knl, expect_completion=True)
diff --git a/test/test_loopy.py b/test/test_loopy.py
index bd5dbbde1..8092e9c6f 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -337,7 +337,7 @@ def test_stencil_with_overfetch(ctx_factory):
                 ],
             assumptions="n>=1")
 
-    knl = lp.add_and_infer_argument_dtypes(knl, dict(a=np.float32))
+    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32))
 
     ref_knl = knl
 
-- 
GitLab