diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index 840b2df418d694fcf4bce0c080282bc5ee782b7c..e657beecbc5453ae5b2390da5a958d2fc9a70771 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -60,6 +60,8 @@ def dump_space(ls): for dt in dim_type.names) +# {{{ make_slab + def make_slab(space, iname, start, stop): zero = isl.Aff.zero_on_domain(space) @@ -114,6 +116,8 @@ def make_slab_from_bound_pwaffs(space, iname, lbound, ubound): & iname_pwaff.le_set(ubound)) +# }}} + def iname_rel_aff(space, iname, rel, aff): """*aff*'s domain space is allowed to not match *space*.""" @@ -138,6 +142,8 @@ def iname_rel_aff(space, iname, rel, aff): raise ValueError("unknown value of 'rel': %s" % rel) +# {{{ static_*_of_pw_aff + def static_extremum_of_pw_aff(pw_aff, constants_only, set_method, what, context): if context is not None: context = isl.align_spaces(context, pw_aff.get_domain_space(), @@ -214,6 +220,10 @@ def static_value_of_pw_aff(pw_aff, constants_only, context=None): return static_extremum_of_pw_aff(pw_aff, constants_only, isl.PwAff.eq_set, "value", context) +# }}} + + +# {{{ duplicate_axes def duplicate_axes(isl_obj, duplicate_inames, new_inames): if isinstance(isl_obj, list): @@ -262,6 +272,8 @@ def duplicate_axes(isl_obj, duplicate_inames, new_inames): return moved_dims.intersect(more_dims) +# }}} + def is_nonnegative(expr, over_set): space = over_set.get_space() @@ -276,6 +288,8 @@ def is_nonnegative(expr, over_set): return over_set.intersect(expr_neg_set).is_empty() +# {{{ convexify + def convexify(domain): """Try a few ways to get *domain* to be a BasicSet, i.e. explicitly convex. @@ -312,6 +326,10 @@ def convexify(domain): print(" %s" % (isl.Set.from_basic_set(dbs).gist(domain))) raise NotImplementedError("Could not find convex representation of set") +# }}} + + +# {{{ boxify def boxify(cache_manager, domain, box_inames, context): var_dict = domain.get_var_dict(dim_type.set) @@ -357,6 +375,8 @@ def boxify(cache_manager, domain, box_inames, context): return convexify(result) +# }}} + def simplify_via_aff(expr): from loopy.symbolic import aff_from_expr, aff_to_expr, get_dependencies @@ -512,6 +532,8 @@ def dim_max_with_elimination(obj, idx): # }}} +# {{{ get_simple_strides + def get_simple_strides(bset, key_by="name"): """Return a dictionary from inames to strides in bset. Each stride is returned as a :class:`islpy.Val`. If no stride can be determined, the @@ -570,4 +592,29 @@ def get_simple_strides(bset, key_by="name"): return result +# }}} + + +# {{{{ find_max_of_pwaff_with_params + +def find_max_of_pwaff_with_params(pw_aff, n_allowed_params): + if n_allowed_params is None: + return pw_aff + + extra_dim_idx = pw_aff.dim(dim_type.param,) + pw_aff = pw_aff.add_dims(dim_type.param, 1) + + zero = isl.Aff.zero_on_domain(pw_aff.domain().space) + extra_dim = zero.set_coefficient_val(dim_type.param, extra_dim_idx, 1) + + pw_aff_set = pw_aff.eq_set(extra_dim) + + pw_aff_set = pw_aff_set.move_dims( + dim_type.set, 0, dim_type.param, n_allowed_params, + pw_aff_set.dim(dim_type.param) - n_allowed_params) + + return pw_aff_set.dim_max(pw_aff_set.dim(dim_type.set)-1) + +# }}} + # vim: foldmethod=marker diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 54236efca651e676745c4198cf509b019b95c084..88b18ff3866548ed674713c3fb86d1ff4cf12916 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -322,7 +322,13 @@ class SetOperationCacheManager: from loopy.isl_helpers import dim_max_with_elimination return self.op(set, "dim_max", dim_max_with_elimination, args) - def base_index_and_length(self, set, iname, context=None): + def base_index_and_length(self, set, iname, context=None, + n_allowed_params_in_length=None): + """ + :arg n_allowed_params_in_length: Simplifies the 'length' + argument so that only the first that many params + (in the domain of *set*) occur. + """ if not isinstance(iname, int): iname_to_dim = set.space.get_var_dict() idx = iname_to_dim[iname][1] @@ -336,7 +342,8 @@ class SetOperationCacheManager: from loopy.isl_helpers import ( static_max_of_pw_aff, static_min_of_pw_aff, - static_value_of_pw_aff) + static_value_of_pw_aff, + find_max_of_pwaff_with_params) from loopy.symbolic import pw_aff_to_expr # {{{ first: try to find static lower bound value @@ -351,11 +358,14 @@ class SetOperationCacheManager: if base_index_aff is not None: base_index = pw_aff_to_expr(base_index_aff) - size = pw_aff_to_expr(static_max_of_pw_aff( - upper_bound_pw_aff - base_index_aff + 1, constants_only=False, + length = find_max_of_pwaff_with_params( + upper_bound_pw_aff - base_index_aff + 1, + n_allowed_params_in_length) + length = pw_aff_to_expr(static_max_of_pw_aff( + length, constants_only=False, context=context)) - return base_index, size + return base_index, length # }}} @@ -367,11 +377,14 @@ class SetOperationCacheManager: base_index = pw_aff_to_expr(base_index_aff) - size = pw_aff_to_expr(static_max_of_pw_aff( - upper_bound_pw_aff - base_index_aff + 1, constants_only=False, + length = find_max_of_pwaff_with_params( + upper_bound_pw_aff - base_index_aff + 1, + n_allowed_params_in_length) + length = pw_aff_to_expr(static_max_of_pw_aff( + length, constants_only=False, context=context)) - return base_index, size + return base_index, length # }}} @@ -1083,4 +1096,43 @@ def guess_var_shape(kernel, var_name): # }}} + +# {{{ find_recursive_dependencies + +def find_recursive_dependencies(kernel, insn_ids): + queue = list(insn_ids) + + result = set(insn_ids) + + while queue: + new_queue = [] + + for insn_id in queue: + insn = kernel.id_to_insn[insn_id] + additionals = insn.depends_on - result + result.update(additionals) + new_queue.extend(additionals) + + queue = new_queue + + return result + +# }}} + + +# {{{ find_reverse_dependencies + +def find_reverse_dependencies(kernel, insn_ids): + """Finds a set of IDs of instructions that depend on one of the insn_ids. + + :arg insn_ids: a set of instruction IDs + """ + return frozenset( + insn.id + for insn in kernel.instructions + if insn.depends_on & insn_ids) + +# }}} + + # vim: foldmethod=marker diff --git a/loopy/transform/array_buffer_map.py b/loopy/transform/array_buffer_map.py index 38e35a94190bff49e537de873f0278d4d8b7b38b..3c7bfed43b9bd02a4be3d71b2317cee94da75b4b 100644 --- a/loopy/transform/array_buffer_map.py +++ b/loopy/transform/array_buffer_map.py @@ -162,9 +162,12 @@ def build_global_storage_to_sweep_map(kernel, access_descriptors, # {{{ compute storage bounds def find_var_base_indices_and_shape_from_inames( - domain, inames, cache_manager, context=None): + domain, inames, cache_manager, context=None, + n_allowed_params_in_shape=None): base_indices_and_sizes = [ - cache_manager.base_index_and_length(domain, iname, context) + cache_manager.base_index_and_length( + domain, iname, context, + n_allowed_params_in_length=n_allowed_params_in_shape) for iname in inames] return list(zip(*base_indices_and_sizes)) @@ -183,7 +186,8 @@ def compute_bounds(kernel, domain, stor2sweep, return find_var_base_indices_and_shape_from_inames( storage_domain, [saxis+"'" for saxis in storage_axis_names], - kernel.cache_manager, context=kernel.assumptions) + kernel.cache_manager, context=kernel.assumptions, + n_allowed_params_in_shape=stor2sweep.dim(isl.dim_type.param)) # }}} diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index db993b771d9088f0644c2406704f5b2e4c97ea89..5ab9dfab3c8ac0669c3e7eaf4091bb3ab4b0e2a2 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -136,7 +136,8 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper): access_descriptors, array_base_map, storage_axis_names, storage_axis_sources, non1_storage_axis_names, - temporary_name, compute_insn_id, compute_read_variables): + temporary_name, compute_insn_id, compute_dep_id, + compute_read_variables): super(RuleInvocationReplacer, self).__init__(rule_mapping_context) self.subst_name = subst_name @@ -152,9 +153,10 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper): self.temporary_name = temporary_name self.compute_insn_id = compute_insn_id + self.compute_dep_id = compute_dep_id self.compute_read_variables = compute_read_variables - self.compute_insn_deps = set() + self.compute_insn_depends_on = set() def map_substitution(self, name, tag, arguments, expn_state): if not ( @@ -222,6 +224,8 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper): def map_kernel(self, kernel): new_insns = [] + excluded_insn_ids = set([self.compute_insn_id, self.compute_dep_id]) + for insn in kernel.instructions: self.replaced_something = False @@ -231,17 +235,17 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper): insn = insn.copy( depends_on=( insn.depends_on - | frozenset([self.compute_insn_id]))) + | frozenset([self.compute_dep_id]))) for dep in insn.depends_on: - if dep == self.compute_insn_id: + if dep in excluded_insn_ids: continue dep_insn = kernel.id_to_insn[dep] if (frozenset(dep_insn.assignee_var_names()) & self.compute_read_variables): - self.compute_insn_deps.update( - insn.depends_on - set([self.compute_insn_id])) + self.compute_insn_depends_on.update( + insn.depends_on - excluded_insn_ids) new_insns.append(insn) @@ -790,6 +794,20 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, expression=compute_expression, # within_inames determined below ) + compute_dep_id = compute_insn_id + added_compute_insns = [compute_insn] + + if temporary_scope == temp_var_scope.GLOBAL: + barrier_insn_id = kernel.make_unique_instruction_id( + based_on=c_subst_name+"_b") + from loopy.kernel.instruction import BarrierInstruction + barrier_insn = BarrierInstruction( + id=barrier_insn_id, + depends_on=frozenset([compute_insn_id]), + kind="global") + compute_dep_id = barrier_insn_id + + added_compute_insns.append(barrier_insn) # }}} @@ -803,12 +821,12 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, access_descriptors, abm, storage_axis_names, storage_axis_sources, non1_storage_axis_names, - temporary_name, compute_insn_id, + temporary_name, compute_insn_id, compute_dep_id, compute_read_variables=get_dependencies(expander(compute_expression))) kernel = invr.map_kernel(kernel) kernel = kernel.copy( - instructions=[compute_insn] + kernel.instructions) + instructions=added_compute_insns + kernel.instructions) kernel = rule_mapping_context.finish_kernel(kernel) # }}} @@ -817,13 +835,43 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, kernel = kernel.copy( instructions=[ - insn.copy(depends_on=frozenset(invr.compute_insn_deps)) + insn.copy(depends_on=frozenset(invr.compute_insn_depends_on)) if insn.id == compute_insn_id else insn for insn in kernel.instructions]) # }}} + # {{{ propagate storage iname subst to dependencies of compute instructions + + from loopy.kernel.tools import find_recursive_dependencies + compute_deps = find_recursive_dependencies( + kernel, frozenset([compute_insn_id])) + + # FIXME: Need to verify that there are no outside dependencies + # on compute_deps + + prior_storage_axis_names = frozenset(storage_axis_subst_dict) + + new_insns = [] + for insn in kernel.instructions: + if (insn.id in compute_deps + and insn.within_inames & prior_storage_axis_names): + insn = (insn + .with_transformed_expressions( + lambda expr: expr_subst_map(expr, kernel, insn)) + .copy(within_inames=frozenset( + storage_axis_subst_dict.get(iname, var(iname)).name + for iname in insn.within_inames))) + + new_insns.append(insn) + else: + new_insns.append(insn) + + kernel = kernel.copy(instructions=new_insns) + + # }}} + # {{{ determine inames for compute insn if precompute_outer_inames is None: diff --git a/test/test_reduction.py b/test/test_reduction.py index 68f6242440a14eeeb2144762f0d5175f2135ffa2..b78509b6318a984d117d00b1a6854d9611db80d1 100644 --- a/test/test_reduction.py +++ b/test/test_reduction.py @@ -214,18 +214,19 @@ def test_local_parallel_reduction(ctx_factory, size): lp.auto_test_vs_ref(ref_knl, ctx, knl) -# FIXME: Make me a test @pytest.mark.parametrize("size", [10000]) -def no_test_global_parallel_reduction(ctx_factory, size): - ctx = ctx_factory() - queue = cl.CommandQueue(ctx) +def test_global_parallel_reduction(ctx_factory, size): + # ctx = ctx_factory() + # queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{[i]: 0 <= i < n }", """ - <> key = make_uint2(i, 324830944) {inames=i} - <> ctr = make_uint4(0, 1, 2, 3) {inames=i,id=init_ctr} - <> vals, ctr = philox4x32_f32(ctr, key) {dep=init_ctr} + for i + <> key = make_uint2(i, 324830944) {inames=i} + <> ctr = make_uint4(0, 1, 2, 3) {inames=i,id=init_ctr} + <> vals, ctr = philox4x32_f32(ctr, key) {dep=init_ctr} + end z = sum(i, vals.s0 + vals.s1 + vals.s2 + vals.s3) """) @@ -238,12 +239,11 @@ def no_test_global_parallel_reduction(ctx_factory, size): knl = lp.split_reduction_inward(knl, "i_inner_outer") from loopy.transform.data import reduction_arg_to_subst_rule knl = reduction_arg_to_subst_rule(knl, "i_outer") - knl = lp.precompute(knl, "red_i_outer_arg", "i_outer") - print(knl) - 1/0 + knl = lp.precompute(knl, "red_i_outer_arg", "i_outer", + temporary_scope=lp.temp_var_scope.GLOBAL) knl = lp.realize_reduction(knl) - evt, (z,) = knl(queue, n=size) + #evt, (z,) = knl(queue, n=size) #lp.auto_test_vs_ref(ref_knl, ctx, knl)