diff --git a/loopy/cse.py b/loopy/cse.py index 87ba115fef17361cde3c27dc595a3606d2c48284..236057ea34f83e0f128d8e8d5f773268816850e9 100644 --- a/loopy/cse.py +++ b/loopy/cse.py @@ -597,8 +597,11 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[], # }}} - if sweep_inames: - leaf_domain_index = kernel.get_leaf_domain_index(frozenset(sweep_inames)) + referenced_inames = frozenset(sweep_inames) | frozenset(usage_arg_deps) + assert referenced_inames <= kernel.all_inames() + + if referenced_inames: + leaf_domain_index = kernel.get_leaf_domain_index(referenced_inames) sweep_domain = kernel.domains[leaf_domain_index] for iname in sweep_inames: @@ -607,6 +610,7 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[], "sweep's leaf domain" % iname) else: sweep_domain = kernel.combine_domains(()) + leaf_domain_index = None (non1_storage_axis_names, new_domain, storage_base_indices, non1_storage_base_indices, non1_storage_shape) = \ @@ -814,8 +818,10 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[], # }}} new_domains = kernel.domains[:] - if sweep_inames: + if leaf_domain_index is not None: new_domains[leaf_domain_index] = new_domain + else: + new_domains.append(new_domain) return kernel.copy( domains=new_domains, diff --git a/loopy/kernel.py b/loopy/kernel.py index c14459298175fff48a900716e053b115880140d8..2f1d1516d7fbf31ea33b414d983549c349a6fb58 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -654,6 +654,7 @@ class LoopKernel(Record): :ivar cache_manager: :ivar lowest_priority_inames: (used internally to realize ILP) :ivar breakable_inames: these inames' loops may be broken up by the scheduler + :ivar isl_context: The following instance variables are only used until :func:`loopy.make_kernel` is finished: @@ -684,7 +685,8 @@ class LoopKernel(Record): cache_manager=None, iname_to_tag_requests=None, lowest_priority_inames=[], breakable_inames=set(), - index_dtype=np.int32): + index_dtype=np.int32, + isl_context=None): """ :arg domain: a :class:`islpy.BasicSet`, or a string parseable to a basic set by the isl. Example: "{[i,j]: 0<=i < 10 and 0<= j < 9}" @@ -888,14 +890,19 @@ class LoopKernel(Record): if isinstance(domains, str): domains = [domains] - ctx = isl.Context() + for domain in domains: + if isinstance(domain, isl.BasicSet): + isl_context = domain.get_ctx() + if isl_context is None: + isl_context = isl.Context() + scalar_arg_names = set(arg.name for arg in args if isinstance(arg, ValueArg)) var_names = ( set(temporary_variables) | set(insn.get_assignee_var_name() for insn in parsed_instructions if insn.temp_var_type is not None)) - domains = _parse_domains(ctx, scalar_arg_names | var_names, domains) + domains = _parse_domains(isl_context, scalar_arg_names | var_names, domains) # }}} @@ -956,7 +963,8 @@ class LoopKernel(Record): applied_iname_rewrites=applied_iname_rewrites, function_manglers=function_manglers, symbol_manglers=symbol_manglers, - index_dtype=index_dtype) + index_dtype=index_dtype, + isl_context=isl_context) # {{{ function mangling @@ -1136,7 +1144,8 @@ class LoopKernel(Record): assert isinstance(domains, tuple) # for caching if not domains: - return isl.BasicSet.universe(self.domains[0].get_space()) + return isl.BasicSet.universe(isl.Space.alloc( + self.isl_context, 0, 0, 0)) result = None for dom_index in domains: diff --git a/test/test_sem_reagan.py b/test/test_sem_reagan.py index f625c3b04e9b39b1005c1577c5990bd10b4de55c..84c9667575aacea9af8a26e91ab12d049425f6be 100644 --- a/test/test_sem_reagan.py +++ b/test/test_sem_reagan.py @@ -26,24 +26,24 @@ def test_tim2d(ctx_factory): knl = lp.make_kernel(ctx.devices[0], "[K] -> {[i,j,e,m,o,gi]: 0<=i,j,m,o<%d and 0<=e<K and 0<=gi<3}" % n, [ - "ur(a,b) := sum_float32(@o, D[a,o]*u[e,o,b])", - "us(a,b) := sum_float32(@o, D[b,o]*u[e,a,o])", + "ur(a,b) := sum(@o, D[a,o]*u[e,o,b])", + "us(a,b) := sum(@o, D[b,o]*u[e,a,o])", #"Gu(mat_entry,a,b) := G[mat_entry,e,m,j]*ur(m,j)", "Gux(a,b) := G$x[0,e,a,b]*ur(a,b)+G$x[1,e,a,b]*us(a,b)", "Guy(a,b) := G$y[1,e,a,b]*ur(a,b)+G$y[2,e,a,b]*us(a,b)", "lap[e,i,j] = " - " sum_float32(m, D[m,i]*Gux(m,j))" - "+ sum_float32(m, D[m,j]*Guy(i,m))" + " sum(m, D[m,i]*Gux(m,j))" + "+ sum(m, D[m,j]*Guy(i,m))" ], [ - lp.ArrayArg("u", dtype, shape=field_shape, order=order), - lp.ArrayArg("lap", dtype, shape=field_shape, order=order), - lp.ArrayArg("G", dtype, shape=(3,)+field_shape, order=order), + lp.GlobalArg("u", dtype, shape=field_shape, order=order), + lp.GlobalArg("lap", dtype, shape=field_shape, order=order), + lp.GlobalArg("G", dtype, shape=(3,)+field_shape, order=order), # lp.ConstantArrayArg("D", dtype, shape=(n, n), order=order), - lp.ArrayArg("D", dtype, shape=(n, n), order=order), + lp.GlobalArg("D", dtype, shape=(n, n), order=order), # lp.ImageArg("D", dtype, shape=(n, n)), lp.ValueArg("K", np.int32, approximately=1000), ], @@ -55,7 +55,7 @@ def test_tim2d(ctx_factory): knl = lp.tag_dimensions(knl, dict(i="l.0", j="l.1", e="g.0")) knl = lp.add_prefetch(knl, "D", ["m", "j", "i","o"]) - knl = lp.add_prefetch(knl, "u", ["i", "j", "o"]) + knl = lp.add_prefetch(knl, "u", ["i", "j", "o_ur", "o_us"]) knl = lp.precompute(knl, "ur(m,j)", np.float32, ["m", "j"]) knl = lp.precompute(knl, "us(i,m)", np.float32, ["i", "m"]) @@ -64,31 +64,23 @@ def test_tim2d(ctx_factory): knl = lp.precompute(knl, "Guy(i,m)", np.float32, ["i", "m"]) knl = lp.add_prefetch(knl, "G$x") + knl = lp.add_prefetch(knl, "G$y") knl = lp.tag_dimensions(knl, dict(o="unr")) knl = lp.tag_dimensions(knl, dict(m="unr")) - return knl + knl = lp.set_instruction_priority(knl, "D_fetch", 5) - def variant_1(knl): - # BUG? why can't the prefetch be in the j loop??! - print knl - from pudb import set_trace; set_trace() - knl = lp.precompute(knl, "ur", np.float32, ["a"]) - print knl - 1/0 - #knl = lp.precompute(knl, "us", np.float32, ["a"]) return knl for variant in [variant_orig]: - #for variant in [variant_1]: kernel_gen = lp.generate_loop_schedules(variant(knl)) kernel_gen = lp.check_kernels(kernel_gen, dict(K=1000)) K = 1000 lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen, - op_count=K*(n*n*n*2*2 + n*n*2*3 + n**3 * 2*2)/1e9, - op_label="GFlops", + op_count=[K*(n*n*n*2*2 + n*n*2*3 + n**3 * 2*2)/1e9], + op_label=["GFlops"], parameters={"K": K}, print_ref_code=True)