From 5619ce81c5a2e1c88aed66bd54c902a3401528bc Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 29 Apr 2013 22:59:06 -0400 Subject: [PATCH] Add type inference for kernel arguments. --- examples/hello-loopy.py | 8 +-- loopy/compiled.py | 135 ++++++++++++++++++++++++++++++---------- loopy/kernel.py | 99 ++++++++++++++++++++++++++--- test/test_loopy.py | 8 +-- 4 files changed, 203 insertions(+), 47 deletions(-) diff --git a/examples/hello-loopy.py b/examples/hello-loopy.py index 629a6e6d0..a6b8e6d8b 100644 --- a/examples/hello-loopy.py +++ b/examples/hello-loopy.py @@ -19,9 +19,9 @@ knl = lp.make_kernel(ctx.devices[0], "{[i]: 0<=i" % (self.name, self.dtype) - class ValueArg(Record): - def __init__(self, name, dtype, approximately=None): - Record.__init__(self, name=name, dtype=np.dtype(dtype), + def __init__(self, name, dtype=None, approximately=None): + if dtype is not None: + dtype = np.dtype(dtype) + + Record.__init__(self, name=name, dtype=dtype, approximately=approximately) def __repr__(self): return "" % (self.name, self.dtype) class ScalarArg(ValueArg): - def __init__(self, name, dtype, approximately=None): + def __init__(self, name, dtype=None, approximately=None): from warnings import warn warn("ScalarArg is a deprecated name of ValueArg", DeprecationWarning, stacklevel=2) @@ -1351,6 +1354,16 @@ class LoopKernel(Record): result.update(dom.get_var_names(dim_type.set)) return frozenset(result) + @memoize_method + def all_params(self): + all_inames = self.all_inames() + + result = set() + for dom in self.domains: + result.update(set(dom.get_var_names(dim_type.param)) - all_inames) + + return frozenset(result) + @memoize_method def all_insn_inames(self): """Return a mapping from instruction ids to inames inside which @@ -1678,6 +1691,78 @@ class LoopKernel(Record): # }}} +# {{{ add and infer argument dtypes + +def add_argument_dtypes(knl, dtype_dict): + dtype_dict = dtype_dict.copy() + new_args = [] + + for arg in knl.args: + new_dtype = dtype_dict.pop(arg.name, None) + if new_dtype is not None: + new_dtype = np.dtype(new_dtype) + if arg.dtype is not None and arg.dtype != new_dtype: + raise RuntimeError( + "argument '%s' already has a different dtype " + "(existing: %s, new: %s)" + % (arg.name, arg.dtype, new_dtype)) + arg = arg.copy(dtype=new_dtype) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + + if dtype_dict: + raise RuntimeError("unused argument dtypes: %s" + % ", ".join(dtype_dict)) + + return knl.copy(args=new_args) + +def infer_argument_dtypes(knl): + new_args = [] + + writer_map = knl.writer_map() + + from loopy.codegen.expression import ( + TypeInferenceMapper, TypeInferenceFailure) + tim = TypeInferenceMapper(knl) + + for arg in knl.args: + if arg.dtype is None: + new_dtype = None + + if arg.name in knl.all_params(): + new_dtype = knl.index_dtype + else: + try: + for write_insn_id in writer_map.get(arg.name, ()): + write_insn = knl.id_to_insn[write_insn_id] + new_tim_dtype = tim(write_insn.expression) + if new_dtype is None: + new_dtype = new_tim_dtype + elif new_dtype != new_tim_dtype: + # Now we know *nothing*. + new_dtype = None + break + + except TypeInferenceFailure: + # Even one type inference failure is enough to + # make this dtype not safe to guess. Don't. + pass + + if new_dtype is not None: + arg = arg.copy(dtype=new_dtype) + + new_args.append(arg) + + return knl.copy(args=new_args) + +def get_arguments_with_incomplete_dtype(knl): + return [arg.name for arg in knl.args + if arg.dtype is None] + +# }}} + # {{{ find_all_insn_inames fixed point iteration def find_all_insn_inames(kernel): diff --git a/test/test_loopy.py b/test/test_loopy.py index c9648a46e..16a096c19 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -204,7 +204,7 @@ def test_wg_too_small(ctx_factory): for gen_knl in kernel_gen: try: - lp.CompiledKernel(ctx, gen_knl) + lp.CompiledKernel(ctx, gen_knl).get_code() except RuntimeError, e: assert "implemented and desired" in str(e) pass # expected! @@ -655,7 +655,7 @@ def test_dependent_loop_bounds(ctx_factory): cknl = lp.CompiledKernel(ctx, knl) print "---------------------------------------------------" - cknl.print_code() + print cknl.get_highlighted_code() print "---------------------------------------------------" @@ -689,7 +689,7 @@ def test_dependent_loop_bounds_2(ctx_factory): inner_tag="l.0") cknl = lp.CompiledKernel(ctx, knl) print "---------------------------------------------------" - cknl.print_code() + print cknl.get_highlighted_code() print "---------------------------------------------------" @@ -727,7 +727,7 @@ def test_dependent_loop_bounds_3(ctx_factory): cknl = lp.CompiledKernel(ctx, knl) print "---------------------------------------------------" - cknl.print_code() + print cknl.get_highlighted_code() print "---------------------------------------------------" knl_bad = lp.split_iname(knl, "jj", 128, outer_tag="g.1", -- GitLab