From f85486e087a3209f34cc3a852cd492261f35a6f0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 12 Aug 2012 14:01:08 -0400 Subject: [PATCH] A few more multi-domain fixes. --- loopy/__init__.py | 3 +++ loopy/codegen/bounds.py | 6 ++++-- loopy/isl_helpers.py | 5 ++--- loopy/kernel.py | 5 ++--- test/test_loopy.py | 10 ++++++++-- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/loopy/__init__.py b/loopy/__init__.py index fe7f8859b..2f8fd8f5d 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -80,6 +80,9 @@ def split_dimension(kernel, split_iname, inner_length, inner_iname = split_iname+"_inner" def process_set(s): + if split_iname not in s.get_var_dict(): + return s + outer_var_nr = s.dim(dim_type.set) inner_var_nr = s.dim(dim_type.set)+1 diff --git a/loopy/codegen/bounds.py b/loopy/codegen/bounds.py index 0237295b2..1ec320bef 100644 --- a/loopy/codegen/bounds.py +++ b/loopy/codegen/bounds.py @@ -22,6 +22,7 @@ def get_bounds_constraints(set, iname, admissible_inames, allow_parameters): elim_type.append(dim_type.param) set = set.eliminate_except(admissible_inames, elim_type) + set = set.compute_divs() basic_sets = set.get_basic_sets() if len(basic_sets) > 1: @@ -106,11 +107,12 @@ def constraint_to_code(ccm, cns): return "%s %s 0" % (ccm(constraint_to_expr(cns)), comp_op) def filter_necessary_constraints(implemented_domain, constraints): - space = implemented_domain.get_space() return [cns for cns in constraints if not implemented_domain.is_subset( - isl.Set.universe(space).add_constraint(cns))] + isl.align_spaces( + isl.BasicSet.universe(cns.get_space()).add_constraint(cns), + implemented_domain))] def generate_bounds_checks(domain, check_inames, implemented_domain): """Will not overapproximate.""" diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index 3eb1ccba2..dab7648af 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -105,11 +105,10 @@ def dump_space(ls): def make_slab(space, iname, start, stop): zero = isl.Aff.zero_on_domain(space) - from islpy import align_spaces if isinstance(start, (isl.Aff, isl.PwAff)): - start = align_spaces(pw_aff_to_aff(start), zero) + start, zero = isl.align_two(pw_aff_to_aff(start), zero) if isinstance(stop, (isl.Aff, isl.PwAff)): - stop = align_spaces(pw_aff_to_aff(stop), zero) + stop, zero = isl.align_two(pw_aff_to_aff(stop), zero) if isinstance(start, int): start = zero + start if isinstance(stop, int): stop = zero + stop diff --git a/loopy/kernel.py b/loopy/kernel.py index 24ca824e7..36575b103 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -1235,7 +1235,8 @@ class LoopKernel(Record): d_var_dict = domain.get_var_dict() dom_intersect_assumptions = ( - isl.align_spaces(self.assumptions, domain) & domain) + isl.align_spaces(self.assumptions, domain, obj_bigger_ok=True) + & domain) lower_bound_pw_aff = ( self.cache_manager.dim_min( dom_intersect_assumptions, @@ -1614,6 +1615,4 @@ class SetOperationCacheManager: - - # vim: foldmethod=marker diff --git a/test/test_loopy.py b/test/test_loopy.py index 1de7eeb03..37d07f04e 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -368,7 +368,10 @@ def test_dependent_loop_bounds_2(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel(ctx.devices[0], - "[n,row_len] -> {[i,jj]: 0<=i<n and 0<=jj<row_len}", + [ + "{[i]: 0<=i<n}", + "{[jj]: 0<=jj<row_len}", + ], [ "<> row_start = a_rowstarts[i]", "<> row_len = a_rowstarts[i+1] - row_start", @@ -400,7 +403,10 @@ def test_dependent_loop_bounds_3(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel(ctx.devices[0], - "[n,row_len] -> {[i,j]: 0<=i<n and 0<=j<row_len}", + [ + "{[i]: 0<=i<n}", + "{[jj]: 0<=jj<row_len}", + ], [ "<> row_len = a_row_lengths[i]", "a[i,j] = 1", -- GitLab