Skip to content
Snippets Groups Projects
Commit 791820ca authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Fix type inference for temporaries.

parent bedbc2a9
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,9 @@ from pymbolic.mapper import CombineMapper
class TypeInferenceFailure(RuntimeError):
pass
class DependencyTypeInferenceFailure(TypeInferenceFailure):
pass
class TypeInferenceMapper(CombineMapper):
def __init__(self, kernel, temporary_variables=None):
self.kernel = kernel
......@@ -77,7 +80,7 @@ class TypeInferenceMapper(CombineMapper):
else:
from loopy import infer_type
if tv.dtype is infer_type:
raise TypeInferenceFailure("attempted type inference on "
raise DependencyTypeInferenceFailure("attempted type inference on "
"variable requiring type inference")
return tv.dtype
......
......@@ -17,26 +17,35 @@ def infer_types_of_temporaries(kernel):
for tv in kernel.temporary_variables.itervalues():
if tv.dtype is infer_type:
queue.append(tv)
new_temp_vars[tv.name] = tv
else:
new_temp_vars[tv.name] = tv
from loopy.codegen.expression import (
TypeInferenceMapper, TypeInferenceFailure)
tim = TypeInferenceMapper(kernel)
TypeInferenceMapper, DependencyTypeInferenceFailure)
tim = TypeInferenceMapper(kernel, new_temp_vars)
first_failure = None
while queue:
tv = queue.pop(0)
try:
dtypes = [
tim(kernel.id_to_insn[writer].expression)
for writer in kernel.writer_map()[tv.name]]
except TypeInferenceFailure:
dtypes = []
writers = kernel.writer_map()[tv.name]
exprs = [kernel.id_to_insn[w].expression for w in writers]
for expr in exprs:
try:
dtypes.append(tim(expr))
except DependencyTypeInferenceFailure:
pass
if not dtypes:
if tv is first_failure:
# this has failed before, give up.
raise RuntimeError("could not determine type of '%s'"
% tv.name)
raise RuntimeError("could not determine type of '%s' from expression(s) '%s'"
% (tv.name,
", ".join(str(e) for e in exprs)))
if first_failure is None:
# remember the first failure for this round through the queue
......@@ -44,6 +53,7 @@ def infer_types_of_temporaries(kernel):
# can't infer type yet, put back into queue
queue.append(tv)
continue
from pytools import is_single_valued
if not is_single_valued(dtypes):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment