From 791820ca8dff17d4c1878fee9c8ae1c083782ed6 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 18 Apr 2012 15:46:22 -0400 Subject: [PATCH] Fix type inference for temporaries. --- loopy/codegen/expression.py | 5 ++++- loopy/preprocess.py | 28 +++++++++++++++++++--------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index 3ccb5d1ed..d77cdbd2a 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -11,6 +11,9 @@ from pymbolic.mapper import CombineMapper class TypeInferenceFailure(RuntimeError): pass +class DependencyTypeInferenceFailure(TypeInferenceFailure): + pass + class TypeInferenceMapper(CombineMapper): def __init__(self, kernel, temporary_variables=None): self.kernel = kernel @@ -77,7 +80,7 @@ class TypeInferenceMapper(CombineMapper): else: from loopy import infer_type if tv.dtype is infer_type: - raise TypeInferenceFailure("attempted type inference on " + raise DependencyTypeInferenceFailure("attempted type inference on " "variable requiring type inference") return tv.dtype diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 3f00f85d6..e010b2e78 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -17,26 +17,35 @@ def infer_types_of_temporaries(kernel): for tv in kernel.temporary_variables.itervalues(): if tv.dtype is infer_type: queue.append(tv) + new_temp_vars[tv.name] = tv else: new_temp_vars[tv.name] = tv from loopy.codegen.expression import ( - TypeInferenceMapper, TypeInferenceFailure) - tim = TypeInferenceMapper(kernel) + TypeInferenceMapper, DependencyTypeInferenceFailure) + tim = TypeInferenceMapper(kernel, new_temp_vars) first_failure = None while queue: tv = queue.pop(0) - try: - dtypes = [ - tim(kernel.id_to_insn[writer].expression) - for writer in kernel.writer_map()[tv.name]] - except TypeInferenceFailure: + dtypes = [] + + writers = kernel.writer_map()[tv.name] + exprs = [kernel.id_to_insn[w].expression for w in writers] + + for expr in exprs: + try: + dtypes.append(tim(expr)) + except DependencyTypeInferenceFailure: + pass + + if not dtypes: if tv is first_failure: # this has failed before, give up. - raise RuntimeError("could not determine type of '%s'" - % tv.name) + raise RuntimeError("could not determine type of '%s' from expression(s) '%s'" + % (tv.name, + ", ".join(str(e) for e in exprs))) if first_failure is None: # remember the first failure for this round through the queue @@ -44,6 +53,7 @@ def infer_types_of_temporaries(kernel): # can't infer type yet, put back into queue queue.append(tv) + continue from pytools import is_single_valued if not is_single_valued(dtypes): -- GitLab