Skip to content
Snippets Groups Projects
Commit 5021fd9e authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Add a facility for type inference on temporary variables.

parent 94b1b76e
No related branches found
No related tags found
No related merge requests found
build
.*.sw[po]
.sw[po]
*~
*.pyc
*.pyo
......
......@@ -50,6 +50,9 @@ optional.
* `<float32>` declares `lhs` as a temporary variable, with shape given
by the ranges of the `lhs` subscripts. (Note that in this case, the
`lhs` subscripts must be pure inames, not expressions, for now.)
Instead of a concrete type, an empty set of angle brackets `<>` may be
given to indicate that type inference should figure out the type of the
temporary.
* `[i,j|k,l]` specifies the inames within which this instruction is run.
Independent copies of the inames `k` and `l` will be made for this
......
......@@ -46,6 +46,9 @@ __all__ = ["ScalarArg", "ArrayArg", "ConstantArrayArg", "ImageArg", "LoopKernel"
"precompute", "add_prefetch"
]
class infer_type:
pass
# }}}
# {{{ dimension split
......
......@@ -8,6 +8,9 @@ from pymbolic.mapper import CombineMapper
# {{{ type inference
class TypeInferenceFailure(RuntimeError):
pass
class TypeInferenceMapper(CombineMapper):
def __init__(self, kernel, temporary_variables=None):
self.kernel = kernel
......@@ -33,7 +36,7 @@ class TypeInferenceMapper(CombineMapper):
pass
else:
if not result is other:
raise TypeError("nothing known about result of operation on "
raise TypeInferenceFailure("nothing known about result of operation on "
"'%s' and '%s'" % (result, other))
return result
......@@ -60,14 +63,20 @@ class TypeInferenceMapper(CombineMapper):
pass
try:
return self.temporary_variables[expr.name].dtype
result = self.temporary_variables[expr.name].dtype
except KeyError:
pass
else:
from loopy import infer_type
if result is infer_type:
raise TypeInferenceFailure("attempted type inference on "
"variable requiring type inference")
return result
if expr.name in self.kernel.all_inames():
return np.dtype(np.int16) # don't force single-precision upcast
raise RuntimeError("type inference: nothing known about '%s'" % expr.name)
raise TypeInferenceFailure("nothing known about '%s'" % expr.name)
def map_lookup(self, expr):
agg_result = self.rec(expr.aggregate)
......
......@@ -96,6 +96,9 @@ def create_temporaries(knl):
new_insns = []
new_temp_vars = knl.temporary_variables.copy()
from loopy.codegen.expression import TypeInferenceMapper
tim = TypeInferenceMapper(knl, new_temp_vars)
for insn in knl.instructions:
from loopy.kernel import (
find_var_base_indices_and_shape_from_inames,
......@@ -104,6 +107,13 @@ def create_temporaries(knl):
if insn.temp_var_type is not None:
assignee_name = insn.get_assignee_var_name()
temp_var_type = insn.temp_var_type
from loopy import infer_type
if temp_var_type is infer_type:
# FIXME dependencies among type-inferred variables
# are not allowed yet.
temp_var_type = tim(insn.expression)
assignee_indices = []
from pymbolic.primitives import Variable
for index_expr in insn.get_assignee_indices():
......@@ -123,7 +133,7 @@ def create_temporaries(knl):
new_temp_vars[assignee_name] = TemporaryVariable(
name=assignee_name,
dtype=np.dtype(insn.temp_var_type),
dtype=temp_var_type,
is_local=None,
base_indices=base_indices,
shape=shape)
......
......@@ -293,6 +293,12 @@ class Instruction(Record):
boostable_into=None,
temp_var_type=None, duplicate_inames_and_tags=[]):
from loopy.symbolic import parse
if isinstance(assignee, str):
assignee = parse(assignee)
if isinstance(expression, str):
assignee = parse(expression)
assert isinstance(forced_iname_deps, set)
assert isinstance(insn_deps, set)
......@@ -557,7 +563,7 @@ class LoopKernel(Record):
"(?P<iname_deps_and_tags>[\s\w,:.]*)"
"(?:\|(?P<duplicate_inames_and_tags>[\s\w,:.]*))?"
"\])?"
"\s*(?:\<(?P<temp_var_type>.+?)\>)?"
"\s*(?:\<(?P<temp_var_type>.*?)\>)?"
"\s*(?P<lhs>.+?)\s*(?<!\:)=\s*(?P<rhs>.+?)"
"\s*?(?:\:\s*(?P<insn_deps>[\s\w,]+))?$"
)
......@@ -648,7 +654,11 @@ class LoopKernel(Record):
duplicate_inames_and_tags = []
if groups["temp_var_type"] is not None:
temp_var_type = groups["temp_var_type"]
if groups["temp_var_type"]:
temp_var_type = np.dtype(groups["temp_var_type"])
else:
from loopy import infer_type
temp_var_type = infer_type
else:
temp_var_type = None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment