From e20729c5c2a2d9437dfb8c23d74d4188d7ca0f7f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 12 Jun 2014 08:56:51 +0100 Subject: [PATCH] Implement loopy.assume --- doc/reference.rst | 4 ++-- loopy/__init__.py | 22 ++++++++++++++++++++++ loopy/kernel/__init__.py | 22 +++++++++++++--------- test/test_loopy.py | 22 ++++++++++++++++++++++ 4 files changed, 59 insertions(+), 11 deletions(-) diff --git a/doc/reference.rst b/doc/reference.rst index f7347c118..4892b1ad7 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -362,6 +362,8 @@ Dealing with Parameters .. autofunction:: fix_parameters +.. autofunction:: assume + Dealing with Substitution Rules ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -376,8 +378,6 @@ Caching, Precomputation and Prefetching .. autofunction:: add_prefetch - Uses :func:`extract_subst` and :func:`precompute`. - Influencing data access ^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/loopy/__init__.py b/loopy/__init__.py index d16d01597..5d9ad93e7 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -908,6 +908,8 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None, directly by putting an index expression into *var_name*. Substitutions such as those occurring in dimension splits are recorded and also applied to these indices. + + This function combines :func:`extract_subst` and :func:`precompute`. """ # {{{ fish indexing out of var_name and into footprint_subscripts @@ -1259,6 +1261,26 @@ def fix_parameters(kernel, **value_dict): # }}} +# {{{ assume + +def assume(kernel, assumptions): + if isinstance(assumptions, str): + assumptions_set_str = "[%s] -> { : %s}" \ + % (",".join(s for s in kernel.outer_params()), + assumptions) + assumptions = isl.BasicSet.read_from_str(kernel.domains[0].get_ctx(), + assumptions_set_str) + + if not isinstance(assumptions, isl.BasicSet): + raise TypeError("'assumptions' must be a BasicSet or a string") + + old_assumptions, new_assumptions = isl.align_two(kernel.assumptions, assumptions) + + return kernel.copy(assumptions=old_assumptions.params() & new_assumptions.params()) + +# }}} + + # {{{ set_options def set_options(kernel, *args, **kwargs): diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index 4fa3e05ca..0c9e4d3d1 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -225,16 +225,8 @@ class LoopKernel(RecordWithoutPickling): assumptions = isl.BasicSet.universe(assumptions_space) elif isinstance(assumptions, str): - all_inames = set() - all_params = set() - for dom in domains: - all_inames.update(dom.get_var_names(dim_type.set)) - all_params.update(dom.get_var_names(dim_type.param)) - - domain_parameters = all_params-all_inames - assumptions_set_str = "[%s] -> { : %s}" \ - % (",".join(s for s in domain_parameters), + % (",".join(s for s in self.outer_params(domains)), assumptions) assumptions = isl.BasicSet.read_from_str(domains[0].get_ctx(), assumptions_set_str) @@ -591,6 +583,18 @@ class LoopKernel(RecordWithoutPickling): return frozenset(result) + def outer_params(self, domains=None): + if domains is None: + domains = self.domains + + all_inames = set() + all_params = set() + for dom in self.domains: + all_inames.update(dom.get_var_names(dim_type.set)) + all_params.update(dom.get_var_names(dim_type.param)) + + return all_params-all_inames + @memoize_method def all_insn_inames(self): """Return a mapping from instruction ids to inames inside which diff --git a/test/test_loopy.py b/test/test_loopy.py index f9e4eee78..d17ec1c0c 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -127,6 +127,28 @@ def test_sized_and_complex_literals(ctx_factory): lp.auto_test_vs_ref(knl, ctx, knl, parameters=dict(n=5)) +def test_assume(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[i]: 0<=i<n}", + "a[i] = a[i] + 1", + [lp.GlobalArg("a", np.float32, shape="n"), "..."]) + + knl = lp.split_iname(knl, "i", 16) + knl = lp.set_loop_priority(knl, "i_outer,i_inner") + knl = lp.assume(knl, "n mod 16 = 0") + knl = lp.assume(knl, "n > 10") + knl = lp.preprocess_kernel(knl, ctx.devices[0]) + kernel_gen = lp.generate_loop_schedules(knl) + + for gen_knl in kernel_gen: + print gen_knl + compiled = lp.CompiledKernel(ctx, gen_knl) + print compiled.get_code() + assert "if" not in compiled.get_code() + + def test_simple_side_effect(ctx_factory): ctx = ctx_factory() -- GitLab