From 657b2ae29a006ca3e90c460090fc716b1511434f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 28 Oct 2013 18:37:13 -0500
Subject: [PATCH] Provide better error message in case of failed type inference
 from missing argument

---
 loopy/codegen/expression.py |  6 ++++--
 loopy/compiled.py           | 10 +++++-----
 loopy/diagnostic.py         |  4 +++-
 loopy/preprocess.py         | 18 ++++++++++++------
 4 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index a672c55c2..322a7c3ae 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -191,7 +191,8 @@ class TypeInferenceMapper(CombineMapper):
             result = obj.dtype
             if result is lp.auto:
                 raise DependencyTypeInferenceFailure(
-                        "temporary variable '%s'" % expr.name)
+                        "temporary variable '%s'" % expr.name,
+                        expr.name)
             else:
                 return result
 
@@ -199,7 +200,8 @@ class TypeInferenceMapper(CombineMapper):
             result = obj.dtype
             if result is None:
                 raise DependencyTypeInferenceFailure(
-                        "argument '%s'" % expr.name)
+                        "argument '%s'" % expr.name,
+                        expr.name)
             else:
                 return result
 
diff --git a/loopy/compiled.py b/loopy/compiled.py
index 3396f482c..53d82a394 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -244,7 +244,7 @@ def generate_integer_arg_finding_from_offsets(gen, kernel, impl_arg_info, flags)
 # }}}
 
 
-# {{{ integer arg finding from offsets
+# {{{ integer arg finding from strides
 
 def generate_integer_arg_finding_from_strides(gen, kernel, impl_arg_info, flags):
     gen("# {{{ find integer arguments from strides")
@@ -666,7 +666,7 @@ class CompiledKernel:
                 if arg.name in self.kernel.get_written_variables())
 
     @memoize_method
-    def get_kernel(self, var_to_dtype_set):
+    def get_typed_and_scheduled_kernel(self, var_to_dtype_set):
         kernel = self.kernel
 
         from loopy.kernel.tools import add_dtypes
@@ -698,8 +698,8 @@ class CompiledKernel:
         return kernel
 
     @memoize_method
-    def cl_kernel_info(self, arg_to_dtype_set=frozenset()):
-        kernel = self.get_kernel(arg_to_dtype_set)
+    def cl_kernel_info(self, arg_to_dtype_set=frozenset(), all_kwargs=None):
+        kernel = self.get_typed_and_scheduled_kernel(arg_to_dtype_set)
 
         from loopy.codegen import generate_code
         code, impl_arg_info = generate_code(kernel, **self.codegen_kwargs)
@@ -730,7 +730,7 @@ class CompiledKernel:
         if arg_to_dtype is not None:
             arg_to_dtype = frozenset(arg_to_dtype.iteritems())
 
-        kernel = self.get_kernel(arg_to_dtype)
+        kernel = self.get_typed_and_scheduled_kernel(arg_to_dtype)
 
         from loopy.codegen import generate_code
         code, arg_info = generate_code(kernel, **self.codegen_kwargs)
diff --git a/loopy/diagnostic.py b/loopy/diagnostic.py
index fc9c55be8..eb301045f 100644
--- a/loopy/diagnostic.py
+++ b/loopy/diagnostic.py
@@ -75,7 +75,9 @@ class TypeInferenceFailure(LoopyError):
 
 
 class DependencyTypeInferenceFailure(TypeInferenceFailure):
-    pass
+    def __init__(self, message, symbol):
+        TypeInferenceFailure.__init__(self, message)
+        self.symbol = symbol
 
 # }}}
 
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index a656ce6fe..8d9f4093a 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
 
 def _infer_var_type(kernel, var_name, type_inf_mapper, subst_expander):
     if var_name in kernel.all_params():
-        return kernel.index_dtype
+        return kernel.index_dtype, []
 
     def debug(s):
         logger.debug("%s: %s" % (kernel.name, s))
@@ -46,7 +46,9 @@ def _infer_var_type(kernel, var_name, type_inf_mapper, subst_expander):
 
     import loopy as lp
 
-    from loopy.codegen.expression import DependencyTypeInferenceFailure
+    symbols_with_unavailable_types = []
+
+    from loopy.diagnostic import DependencyTypeInferenceFailure
     for writer_insn_id in kernel.writer_map().get(var_name, []):
         writer_insn = kernel.id_to_insn[writer_insn_id]
         if not isinstance(writer_insn, lp.ExpressionInstruction):
@@ -64,16 +66,17 @@ def _infer_var_type(kernel, var_name, type_inf_mapper, subst_expander):
 
         except DependencyTypeInferenceFailure, e:
             debug("             failed: %s" % e)
+            symbols_with_unavailable_types.append(e.symbol)
 
     if not dtypes:
-        return None
+        return None, symbols_with_unavailable_types
 
     from pytools import is_single_valued
     if not is_single_valued(dtypes):
         raise LoopyError("ambiguous type inference for '%s'"
                 % var_name)
 
-    return dtypes[0]
+    return dtypes[0], []
 
 
 class _DictUnionView:
@@ -153,7 +156,8 @@ def infer_unknown_types(kernel, expect_completion=False):
 
         debug("inferring type for %s %s" % (type(item).__name__, item.name))
 
-        result = _infer_var_type(kernel, item.name, type_inf_mapper, subst_expander)
+        result, symbols_with_unavailable_types = \
+                _infer_var_type(kernel, item.name, type_inf_mapper, subst_expander)
 
         failed = result is None
         if not failed:
@@ -172,7 +176,9 @@ def infer_unknown_types(kernel, expect_completion=False):
                 # this item has failed before, give up.
                 if expect_completion:
                     raise LoopyError(
-                            "could not determine type of '%s'" % item.name)
+                            "could not determine type of '%s' "
+                            "(need type of '%s'--check for missing arguments)"
+                            % (item.name, ", ".join(symbols_with_unavailable_types)))
                 else:
                     # We're done here.
                     break
-- 
GitLab