diff --git a/loopy/schedule/schedule_checker/__init__.py b/loopy/schedule/schedule_checker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8062070f8f1c923e6eae2ba75320ff77d20dfe --- /dev/null +++ b/loopy/schedule/schedule_checker/__init__.py @@ -0,0 +1,281 @@ + + +def get_statement_pair_dependency_sets_from_legacy_knl(knl): + """Return a list of :class:`StatementPairDependencySet` instances created + for a :class:`loopy.LoopKernel` containing legacy dependencies. Create + the new dependencies according to the following rules. (1) If + a dependency exists between ``insn0`` and ``insn1``, create the dependency + ``SAME(SNC)`` where ``SNC`` is the set of non-concurrent inames used + by both ``insn0`` and ``insn1``, and ``SAME`` is the relationship specified + by the ``SAME`` attribute of :class:`DependencyType`. (2) For each subset + of non-concurrent inames used by any instruction, find the set of all + instructions using those inames, create a directed graph with these + instructions as nodes and edges representing a 'happens before' + relationship specified by each dependency, find the sources and sinks within + this graph, and connect each sink to each source (sink happens before + source) with a ``PRIOR(SNC)`` dependency, where ``PRIOR`` is the + relationship specified by the ``PRIOR`` attribute of + :class:`DependencyType`. 
+ + """ + + from schedule_checker.dependency import ( + create_dependencies_from_legacy_knl, + ) + + # Preprocess if not already preprocessed + # note that kernels must always be preprocessed before scheduling + from loopy.kernel import KernelState + if knl.state < KernelState.PREPROCESSED: + from loopy import preprocess_kernel + preprocessed_knl = preprocess_kernel(knl) + else: + preprocessed_knl = knl + + # Create StatementPairDependencySet(s) from kernel dependencies + + return create_dependencies_from_legacy_knl(preprocessed_knl) + + +# TODO create a set of broken kernels to test against +# (small kernels to test a specific case) +# TODO work on granularity of encapsulation, encapsulate some of this in +# separate functions +def check_schedule_validity( + knl, + deps_and_domains, + schedule_items, + prohibited_var_names=set(), + verbose=False, + _use_scheduled_kernel_to_obtain_loop_priority=False): + + from schedule_checker.dependency import ( + create_dependency_constraint, + ) + from schedule_checker.schedule import LexSchedule + from schedule_checker.lexicographic_order_map import ( + get_statement_ordering_map, + ) + from schedule_checker.sched_check_utils import ( + prettier_map_string, + ) + + # Preprocess if not already preprocessed + # note that kernels must always be preprocessed before scheduling + from loopy.kernel import KernelState + if knl.state < KernelState.PREPROCESSED: + from loopy import preprocess_kernel + preprocessed_knl = preprocess_kernel(knl) + else: + preprocessed_knl = knl + + if not prohibited_var_names: + prohibited_var_names = preprocessed_knl.all_inames() + + if verbose: + print("="*80) + print("StatementDependencies w/domains:") + for dep_set in deps_and_domains: + print(dep_set) + print(dep_set.dom_before) + print(dep_set.dom_after) + + # Print kernel info ------------------------------------------------------ + print("="*80) + #print("Kernel:") + #print(scheduled_knl) + #from loopy import generate_code_v2 + 
#print(generate_code_v2(scheduled_knl).device_code()) + print("="*80) + #print("Iname tags: %s" % (scheduled_knl.iname_to_tags)) + print("="*80) + print("Loopy schedule:") + for sched_item in schedule_items: + print(sched_item) + #print("scheduled iname order:") + #print(sched_iname_order) + + print("="*80) + print("Looping through dep pairs...") + + # For each dependency, create+test schedule containing pair of insns------ + sched_is_valid = True + for statement_pair_dep_set in deps_and_domains: + s_before = statement_pair_dep_set.statement_before + s_after = statement_pair_dep_set.statement_after + dom_before = statement_pair_dep_set.dom_before + dom_after = statement_pair_dep_set.dom_after + + if verbose: + print("="*80) + print("statement dep set:") + print(statement_pair_dep_set) + print("dom_before:", dom_before) + print("dom_after:", dom_after) + + # Create a mapping of {statement instance: lex point} + # including only instructions involved in this dependency + sched = LexSchedule( + preprocessed_knl, + schedule_items, + s_before.insn_id, + s_after.insn_id, + prohibited_var_names=prohibited_var_names, + ) + + #print("-"*80) + #print("LexSchedule before processing:") + #print(sched) + + lp_insn_id_to_lex_sched_id = sched.loopy_insn_id_to_lex_sched_id() + if verbose: + print("-"*80) + print("LexSchedule with inames added:") + print(sched) + print("dict{lp insn id : sched sid int}:") + print(lp_insn_id_to_lex_sched_id) + + # Get an isl map representing the LexSchedule; + # this requires the iname domains + + sched_map_symbolic_before, sched_map_symbolic_after = \ + sched.create_symbolic_isl_maps( + dom_before, + dom_after, + ) + + if verbose: + print("dom_before:\n", dom_before) + print("dom_after:\n", dom_after) + print("LexSchedule after creating symbolic isl map:") + print(sched) + print("LexSched:") + print(prettier_map_string(sched_map_symbolic_before)) + print(prettier_map_string(sched_map_symbolic_after)) + #print("-"*80) + + # get map representing 
lexicographic ordering + lex_order_map_symbolic = sched.get_lex_order_map_for_symbolic_sched() + """ + if verbose: + print("lex order map symbolic:") + print(prettier_map_string(lex_order_map_symbolic)) + print("space (lex time -> lex time):") + print(lex_order_map_symbolic.space) + print("-"*80) + """ + + # create statement instance ordering, + # maps each statement instance to all statement instances occuring later + sio = get_statement_ordering_map( + sched_map_symbolic_before, + sched_map_symbolic_after, + lex_order_map_symbolic, + ) + + if verbose: + print("statement instance ordering:") + print(prettier_map_string(sio)) + print("SIO space (statement instances -> statement instances):") + print(sio.space) + print("-"*80) + + # create a map representing constraints from the dependency, + # maps statement instance to all statement instances that must occur later + constraint_map = create_dependency_constraint( + statement_pair_dep_set, + knl.loop_priority, + lp_insn_id_to_lex_sched_id, + sched.unused_param_name, + sched.statement_var_name, + ) + # TODO figure out how to keep a consistent lp_insn_id_to_lex_sched_id map + # when dependency creation is separate from schedule checking + + # align constraint map spaces to match sio so we can compare them + if verbose: + print("constraint map space (before aligning):") + print(constraint_map.space) + + # align params + aligned_constraint_map = constraint_map.align_params(sio.space) + + # align in_ dims + import islpy as isl + from schedule_checker.sched_check_utils import ( + reorder_dims_by_name, + ) + sio_in_names = sio.space.get_var_names(isl.dim_type.in_) + aligned_constraint_map = reorder_dims_by_name( + aligned_constraint_map, + isl.dim_type.in_, + sio_in_names, + add_missing=False, + new_names_are_permutation_only=True, + ) + + # align out dims + sio_out_names = sio.space.get_var_names(isl.dim_type.out) + aligned_constraint_map = reorder_dims_by_name( + aligned_constraint_map, + isl.dim_type.out, + 
sio_out_names, + add_missing=False, + new_names_are_permutation_only=True, + ) + + if verbose: + print("constraint map space (after aligning):") + print(aligned_constraint_map.space) + print("constraint map:") + print(prettier_map_string(aligned_constraint_map)) + + assert aligned_constraint_map.space == sio.space + assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.in_) + == sio.space.get_var_names(isl.dim_type.in_)) + assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.out) + == sio.space.get_var_names(isl.dim_type.out)) + assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.param) + == sio.space.get_var_names(isl.dim_type.param)) + + if not aligned_constraint_map.is_subset(sio): + + sched_is_valid = False + + if verbose: + print("================ constraint check failure =================") + print("constraint map not subset of SIO") + print("dependency:") + print(statement_pair_dep_set) + print("statement instance ordering:") + print(prettier_map_string(sio)) + print("constraint_map.gist(sio):") + print(aligned_constraint_map.gist(sio)) + print("sio.gist(constraint_map)") + print(sio.gist(aligned_constraint_map)) + print("loop priority known:") + print(preprocessed_knl.loop_priority) + """ + from schedule_checker.sched_check_utils import ( + get_concurrent_inames, + ) + conc_inames, non_conc_inames = get_concurrent_inames(scheduled_knl) + print("concurrent inames:", conc_inames) + print("sequential inames:", non_conc_inames) + print("constraint map space (stmt instances -> stmt instances):") + print(aligned_constraint_map.space) + print("SIO space (statement instances -> statement instances):") + print(sio.space) + print("constraint map:") + print(prettier_map_string(aligned_constraint_map)) + print("statement instance ordering:") + print(prettier_map_string(sio)) + print("{insn id -> sched sid int} dict:") + print(lp_insn_id_to_lex_sched_id) + """ + 
print("===========================================================") + + return sched_is_valid diff --git a/loopy/schedule/schedule_checker/dependency.py b/loopy/schedule/schedule_checker/dependency.py new file mode 100644 index 0000000000000000000000000000000000000000..a780a036d909953cf285c2da60c55ef953a4a97e --- /dev/null +++ b/loopy/schedule/schedule_checker/dependency.py @@ -0,0 +1,924 @@ +import islpy as isl + + +class DependencyType: + """Strings specifying a particular type of dependency relationship. + + .. attribute:: SAME + + A :class:`str` specifying the following dependency relationship: + + If ``S = {i, j, ...}`` is a set of inames used in both statements + ``insn0`` and ``insn1``, and ``{i', j', ...}`` represent the values + of the inames in ``insn0``, and ``{i, j, ...}`` represent the + values of the inames in ``insn1``, then the dependency + ``insn0 happens before insn1 iff SAME({i, j})`` specifies that + ``insn0 happens before insn1 iff {i' = i and j' = j and ...}``. + Note that ``SAME({}) = True``. + + .. attribute:: PRIOR + + A :class:`str` specifying the following dependency relationship: + + If ``S = {i, j, k, ...}`` is a set of inames used in both statements + ``insn0`` and ``insn1``, and ``{i', j', k', ...}`` represent the values + of the inames in ``insn0``, and ``{i, j, k, ...}`` represent the + values of the inames in ``insn1``, then the dependency + ``insn0 happens before insn1 iff PRIOR({i, j, k})`` specifies one of + two possibilities, depending on whether the loop nest ordering is + known. If the loop nest ordering is unknown, then + ``insn0 happens before insn1 iff {i' < i and j' < j and k' < k ...}``. + If the loop nest ordering is known, the condition becomes + ``{i', j', k', ...}`` is lexicographically less than ``{i, j, k, ...}``, + i.e., ``i' < i or (i' = i and j' < j) or (i' = i and j' = j and k' < k) ...``. 
+ + """ + + SAME = "same" + PRIOR = "prior" + + +class StatementPairDependencySet(object): + """A set of dependencies between two statements. + + .. attribute:: statement_before + + A :class:`LexScheduleStatement` depended on by statement_after. + + .. attribute:: statement_after + + A :class:`LexScheduleStatement` which depends on statement_before. + + .. attribute:: deps + + A :class:`dict` mapping instances of :class:`DependencyType` to + the Loopy kernel inames involved in that particular + dependency relationship. + + .. attribute:: dom_before + + A :class:`islpy.BasicSet` representing the domain for the + dependee statement. + + .. attribute:: dom_after + + A :class:`islpy.BasicSet` representing the domain for the + dependee statement. + + """ + + def __init__( + self, + statement_before, + statement_after, + deps, # {dep_type: iname_set} + dom_before=None, + dom_after=None, + ): + self.statement_before = statement_before + self.statement_after = statement_after + self.deps = deps + self.dom_before = dom_before + self.dom_after = dom_after + + def __eq__(self, other): + return ( + self.statement_before == other.statement_before + and self.statement_after == other.statement_after + and self.deps == other.deps + and self.dom_before == other.dom_before + and self.dom_after == other.dom_after + ) + + def __lt__(self, other): + return self.__hash__() < other.__hash__() + + def __hash__(self): + return hash(repr(self)) + + def update_persistent_hash(self, key_hash, key_builder): + """Custom hash computation function for use with + :class:`pytools.persistent_dict.PersistentDict`. 
+ """ + + key_builder.rec(key_hash, self.statement_before) + key_builder.rec(key_hash, self.statement_after) + key_builder.rec(key_hash, self.deps) + key_builder.rec(key_hash, self.dom_before) + key_builder.rec(key_hash, self.dom_after) + + def __str__(self): + result = "%s --before->\n%s iff\n " % ( + self.statement_before, self.statement_after) + return result + " and\n ".join( + ["(%s : %s)" % (dep_type, inames) + for dep_type, inames in self.deps.items()]) + + +def create_elementwise_comparison_conjunction_set( + names0, names1, islvars, op="eq"): + """Create a set constrained by the conjunction of conditions comparing + `names0` to `names1`. + + .. arg names0: A list of :class:`str` representing variable names. + + .. arg names1: A list of :class:`str` representing variable names. + + .. arg islvars: A dictionary from variable names to :class:`PwAff` + instances that represent each of the variables + (islvars may be produced by `islpy.make_zero_and_vars`). The key + '0' is also include and represents a :class:`PwAff` zero constant. + + .. arg op: A :class:`str` describing the operator to use when creating + the set constraints. Options: `eq` for `=`, `lt` for `<` + + .. return: A set involving `islvars` cosntrained by the constraints + `{names0[0] names1[0] and names0[1] names1[1] and ...}`. 
+ + """ + + # initialize set with constraint that is always true + conj_set = islvars[0].eq_set(islvars[0]) + for n0, n1 in zip(names0, names1): + if op == "eq": + conj_set = conj_set & islvars[n0].eq_set(islvars[n1]) + elif op == "lt": + conj_set = conj_set & islvars[n0].lt_set(islvars[n1]) + + return conj_set + + +def _convert_constraint_set_to_map(constraint_set, mv_count, src_position=None): + dim_type = isl.dim_type + constraint_map = isl.Map.from_domain(constraint_set) + if src_position: + return constraint_map.move_dims( + dim_type.out, 0, dim_type.in_, src_position, mv_count) + else: + return constraint_map.move_dims( + dim_type.out, 0, dim_type.in_, mv_count, mv_count) + + +def create_dependency_constraint( + statement_dep_set, + loop_priorities, + insn_id_to_int, + unused_param_name, + statement_var_name, + statement_var_pose=0, + dom_inames_ordered_before=None, + dom_inames_ordered_after=None, + ): + """Create a statement dependency constraint represented as a map from + each statement instance to statement instances that must occur later, + i.e., ``{[s'=0, i', j'] -> [s=1, i, j] : condition on {i', j', i, j}}`` + indicates that statement ``0`` comes before statment ``1`` when the + specified condition on inames ``i',j',i,j`` is met. ``i'`` and ``j'`` + are the values of inames ``i`` and ``j`` in first statement instance. + + .. arg statement_dep_set: A :class:`StatementPairDependencySet` describing + the dependency relationship between the two statements. + + .. arg loop_priorities: A list of tuples from the ``loop_priority`` + attribute of :class:`loopy.LoopKernel` specifying the loop nest + ordering rules. + + .. arg insn_id_to_int: A :class:`dict` mapping insn_id to int_id, where + 'insn_id' and 'int_id' refer to the 'insn_id' and 'int_id' attributes + of :class:`LexScheduleStatement`. + + .. 
arg unused_param_name: A :class:`str` that specifies the name of a + dummy isl parameter assigned to variables in domain elements of the + isl map that represent inames unused in a particular statement + instance. The domain space of the generated isl map will have a + dimension for every iname used in any statement instance found in + the program ordering. An element in the domain of this map may + represent a statement instance that does not lie within iname x, but + will still need to assign a value to the x domain variable. In this + case, the parameter unused_param_name is assigned to x. + + .. arg statement_var_name: A :class:`str` specifying the name of the + isl variable used to represent the unique :class:`int` statement id. + + .. arg statement_var_pose: A :class:`int` specifying which position in the + statement instance tuples holds the dimension representing the + statement id. Defaults to ``0``. + + .. arg dom_inames_ordered_before: A :class:`list` of :class:`str` + specifying an order for the dimensions representing dependee inames. + + .. arg dom_inames_ordered_after: A :class:`list` of :class:`str` + specifying an order for the dimensions representing depender inames. + + .. return: An :class:`islpy.Map` mapping each statement instance to all + statement instances that must occur later according to the constraints. 
+ + """ + + from schedule_checker.sched_check_utils import ( + make_islvars_with_marker, + append_apostrophes, + add_dims_to_isl_set, + reorder_dims_by_name, + create_new_isl_set_with_primes, + ) + # This function uses the dependency given to create the following constraint: + # Statement [s,i,j] comes before statement [s',i',j'] iff + + from schedule_checker.sched_check_utils import ( + list_var_names_in_isl_sets, + ) + if dom_inames_ordered_before is None: + dom_inames_ordered_before = list_var_names_in_isl_sets( + [statement_dep_set.dom_before]) + if dom_inames_ordered_after is None: + dom_inames_ordered_after = list_var_names_in_isl_sets( + [statement_dep_set.dom_after]) + + # create some (ordered) isl vars to use, e.g., {s, i, j, s', i', j'} + islvars = make_islvars_with_marker( + var_names_needing_marker=[statement_var_name]+dom_inames_ordered_before, + other_var_names=[statement_var_name]+dom_inames_ordered_after, + param_names=[unused_param_name], + marker="'", + ) + statement_var_name_prime = statement_var_name+"'" + + # get (ordered) list of unused before/after inames + inames_before_unused = [] + for iname in dom_inames_ordered_before: + if iname not in statement_dep_set.dom_before.get_var_names(isl.dim_type.out): + inames_before_unused.append(iname + "'") + inames_after_unused = [] + for iname in dom_inames_ordered_after: + if iname not in statement_dep_set.dom_after.get_var_names(isl.dim_type.out): + inames_after_unused.append(iname) + + # TODO are there ever unused inames now that we're separating the in/out spaces? 
+ if inames_before_unused or inames_after_unused: + assert False + + # initialize constraints to False + # this will disappear as soon as we add a constraint + all_constraints_set = islvars[0].eq_set(islvars[0] + 1) + + # for each (dep_type, inames) pair, create 'happens before' constraint, + # all_constraints_set will be the union of all these constraints + dt = DependencyType + for dep_type, inames in statement_dep_set.deps.items(): + # need to put inames in a list so that order of inames and inames' + # matches when calling create_elementwise_comparison_conj... + if not isinstance(inames, list): + inames_list = list(inames) + else: + inames_list = inames[:] + inames_prime = append_apostrophes(inames_list) # e.g., [j', k'] + + if dep_type == dt.SAME: + constraint_set = create_elementwise_comparison_conjunction_set( + inames_prime, inames_list, islvars, op="eq") + elif dep_type == dt.PRIOR: + + priority_known = False + # if nesting info is provided: + if loop_priorities: + # assumes all loop_priority tuples are consistent + + # with multiple priority tuples, determine whether the combined + # info they contain can give us a single, full proiritization, + # e.g., if prios={(a, b), (b, c), (c, d, e)}, then we know + # a -> b -> c -> d -> e + + # remove irrelevant inames from priority tuples (because we're + # about to perform a costly operation on remaining tuples) + relevant_priorities = set() + for p_tuple in loop_priorities: + new_tuple = [iname for iname in p_tuple if iname in inames_list] + # empty tuples and single tuples don't help us define + # a nesting, so ignore them (if we're dealing with a single + # iname, priorities will be ignored later anyway) + if len(new_tuple) > 1: + relevant_priorities.add(tuple(new_tuple)) + + # create a mapping from each iname to inames that must be + # nested inside that iname + nested_inside = {} + for iname in inames_list: + comes_after_iname = set() + for p_tuple in relevant_priorities: + if iname in p_tuple: + 
comes_after_iname.update([ + iname for iname in + p_tuple[p_tuple.index(iname)+1:]]) + nested_inside[iname] = comes_after_iname + + from schedule_checker.sched_check_utils import ( + get_orderings_of_length_n) + # get all orderings that are explicitly allowed by priorities + orders = get_orderings_of_length_n( + nested_inside, + required_length=len(inames_list), + #return_first_found=True, + return_first_found=False, # slower; allows priorities test below + ) + + if orders: + # test for invalid priorities (includes cycles) + if len(orders) != 1: + raise ValueError( + "create_dependency_constriant encountered invalid " + "priorities %s" + % (loop_priorities)) + priority_known = True + priority_tuple = orders.pop() + + # if only one loop, we know the priority + if not priority_known and len(inames_list) == 1: + priority_tuple = tuple(inames_list) + priority_known = True + + if priority_known: + # PRIOR requires statement_before complete previous iterations + # of loops before statement_after completes current iteration + # according to loop nest order + inames_list_nest_ordered = [ + iname for iname in priority_tuple + if iname in inames_list] + inames_list_nest_ordered_prime = append_apostrophes( + inames_list_nest_ordered) + if set(inames_list_nest_ordered) != set(inames_list): + # TODO could this happen? 
+ assert False + + from schedule_checker.lexicographic_order_map import ( + get_lex_order_constraint + ) + # TODO handle case where inames list is empty + constraint_set = get_lex_order_constraint( + islvars, + inames_list_nest_ordered_prime, + inames_list_nest_ordered, + ) + else: # priority not known + # PRIOR requires upper left quadrant happen before: + constraint_set = create_elementwise_comparison_conjunction_set( + inames_prime, inames_list, islvars, op="lt") + + # TODO remove, this shouldn't happen anymore + # set unused vars == unused dummy param + for iname in inames_before_unused+inames_after_unused: + constraint_set = constraint_set & islvars[iname].eq_set( + islvars[unused_param_name]) + + # set statement_var_name == statement # + s_before_int = insn_id_to_int[statement_dep_set.statement_before.insn_id] + s_after_int = insn_id_to_int[statement_dep_set.statement_after.insn_id] + constraint_set = constraint_set & islvars[statement_var_name_prime].eq_set( + islvars[0]+s_before_int) + constraint_set = constraint_set & islvars[statement_var_name].eq_set( + islvars[0]+s_after_int) + + # union this constraint_set with all_constraints_set + all_constraints_set = all_constraints_set | constraint_set + + # convert constraint set to map + all_constraints_map = _convert_constraint_set_to_map( + all_constraints_set, + mv_count=len(dom_inames_ordered_after)+1, # +1 for statement var + src_position=len(dom_inames_ordered_before)+1, # +1 for statement var + ) + + # now apply domain sets to constraint variables + + # add statement variable to doms to enable intersection + range_to_intersect = add_dims_to_isl_set( + statement_dep_set.dom_after, isl.dim_type.out, + [statement_var_name], statement_var_pose) + domain_constraint_set = create_new_isl_set_with_primes( + statement_dep_set.dom_before) + domain_to_intersect = add_dims_to_isl_set( + domain_constraint_set, isl.dim_type.out, + [statement_var_name_prime], statement_var_pose) + + # insert inames missing from doms to 
enable intersection + domain_to_intersect = reorder_dims_by_name( + domain_to_intersect, isl.dim_type.out, + append_apostrophes([statement_var_name] + dom_inames_ordered_before), + add_missing=True) + range_to_intersect = reorder_dims_by_name( + range_to_intersect, + isl.dim_type.out, + [statement_var_name] + dom_inames_ordered_after, + add_missing=True) + + # intersect doms + map_with_loop_domain_constraints = all_constraints_map.intersect_domain( + domain_to_intersect).intersect_range(range_to_intersect) + + return map_with_loop_domain_constraints + + +# TODO no longer used, remove +def _create_5pt_stencil_dependency_constraint( + dom_before_constraint_set, + dom_after_constraint_set, + sid_before, + sid_after, + space_iname, + time_iname, + unused_param_name, + statement_var_name, + statement_var_pose=0, + all_dom_inames_ordered=None, + ): + + from schedule_checker.sched_check_utils import ( + make_islvars_with_marker, + append_apostrophes, + add_dims_to_isl_set, + reorder_dims_by_name, + create_new_isl_set_with_primes, + ) + # This function uses the dependency given to create the following constraint: + # Statement [s,i,j] comes before statement [s',i',j'] iff + + from schedule_checker.sched_check_utils import ( + list_var_names_in_isl_sets, + ) + if all_dom_inames_ordered is None: + all_dom_inames_ordered = list_var_names_in_isl_sets( + [dom_before_constraint_set, dom_after_constraint_set]) + + # create some (ordered) isl vars to use, e.g., {s, i, j, s', i', j'} + islvars = make_islvars_with_marker( + var_names_needing_marker=[statement_var_name]+all_dom_inames_ordered, + other_var_names=[statement_var_name]+all_dom_inames_ordered, + param_names=[unused_param_name], + marker="'", + ) + statement_var_name_prime = statement_var_name+"'" + + # get (ordered) list of unused before/after inames + inames_before_unused = [] + for iname in all_dom_inames_ordered: + if iname not in dom_before_constraint_set.get_var_names(isl.dim_type.out): + 
inames_before_unused.append(iname + "'") + inames_after_unused = [] + for iname in all_dom_inames_ordered: + if iname not in dom_after_constraint_set.get_var_names(isl.dim_type.out): + inames_after_unused.append(iname) + + # initialize constraints to False + # this will disappear as soon as we add a constraint + #all_constraints_set = islvars[0].eq_set(islvars[0] + 1) + + space_iname_prime = space_iname + "'" + time_iname_prime = time_iname + "'" + one = islvars[0] + 1 + two = islvars[0] + 2 + # global: + """ + constraint_set = ( + islvars[time_iname_prime].gt_set(islvars[time_iname]) & + ( + (islvars[space_iname_prime]-two).lt_set(islvars[space_iname]) & + islvars[space_iname].lt_set(islvars[space_iname_prime]+two) + ) + | + islvars[time_iname_prime].gt_set(islvars[time_iname] + one) & + islvars[space_iname].eq_set(islvars[space_iname_prime]) + ) + """ + # local dep: + constraint_set = ( + islvars[time_iname].eq_set(islvars[time_iname_prime] + one) & ( + (islvars[space_iname]-two).lt_set(islvars[space_iname_prime]) & + islvars[space_iname_prime].lt_set(islvars[space_iname]+two)) + | + (islvars[time_iname].eq_set(islvars[time_iname_prime] + two) + & islvars[space_iname_prime].eq_set(islvars[space_iname])) + ) + + # set unused vars == unused dummy param + for iname in inames_before_unused+inames_after_unused: + constraint_set = constraint_set & islvars[iname].eq_set( + islvars[unused_param_name]) + + # set statement_var_name == statement # + constraint_set = constraint_set & islvars[statement_var_name_prime].eq_set( + islvars[0]+sid_before) + constraint_set = constraint_set & islvars[statement_var_name].eq_set( + islvars[0]+sid_after) + + # convert constraint set to map + all_constraints_map = _convert_constraint_set_to_map( + constraint_set, len(all_dom_inames_ordered) + 1) # +1 for statement var + + # now apply domain sets to constraint variables + + # add statement variable to doms to enable intersection + range_to_intersect = add_dims_to_isl_set( + 
def create_arbitrary_dependency_constraint(
        constraint_str,
        dom_before_constraint_set,
        dom_after_constraint_set,
        sid_before,
        sid_after,
        unused_param_name,
        statement_var_name,
        statement_var_pose=0,
        all_dom_inames_ordered=None,
        ):
    """Create an isl map relating 'before' statement instances to 'after'
    statement instances according to an arbitrary constraint expression.

    :arg constraint_str: A :class:`str` containing sub-constraints of the form
        ``<expr> <op> <expr>`` (with *op* one of ``<=``, ``>=``, ``<``, ``>``,
        ``=``) joined by the literal words ``and``/``or``.

    :arg dom_before_constraint_set: An :class:`islpy.BasicSet` representing
        the iname domain of the 'before' statement.

    :arg dom_after_constraint_set: An :class:`islpy.BasicSet` representing
        the iname domain of the 'after' statement.

    :arg sid_before: An :class:`int` statement id for the 'before' statement.

    :arg sid_after: An :class:`int` statement id for the 'after' statement.

    :arg unused_param_name: A :class:`str` naming the dummy parameter that
        inames unused in either domain are pinned to.

    :arg statement_var_name: A :class:`str` naming the statement-id dimension.

    :arg statement_var_pose: An :class:`int` position at which the statement
        dimension is inserted into each domain (default 0).

    :arg all_dom_inames_ordered: An optional ordered list of iname strings;
        if ``None``, it is derived from the two domain sets.

    :returns: An :class:`islpy.Map` from 'before' statement instances
        (variables marked with ``"p"``) to 'after' statement instances,
        intersected with both iname domains.
    """
    # TODO test after switching primes to before vars

    from schedule_checker.sched_check_utils import (
        make_islvars_with_marker,
        append_marker_to_strings,
        add_dims_to_isl_set,
        reorder_dims_by_name,
        create_new_isl_set_with_primes,
        list_var_names_in_isl_sets,
    )

    if all_dom_inames_ordered is None:
        all_dom_inames_ordered = list_var_names_in_isl_sets(
            [dom_before_constraint_set, dom_after_constraint_set])

    # create some (ordered) isl vars to use; the "p"-marked copies represent
    # the 'before' statement instance, e.g., {sp, ip, jp, s, i, j}
    islvars = make_islvars_with_marker(
        var_names_needing_marker=[statement_var_name]+all_dom_inames_ordered,
        other_var_names=[statement_var_name]+all_dom_inames_ordered,
        param_names=[unused_param_name],
        marker="p",
        )  # TODO figure out before/after notation
    statement_var_name_prime = statement_var_name+"p"

    # get (ordered) list of inames not used in each domain; these are later
    # pinned to the unused dummy parameter so they do not constrain anything
    inames_before_unused = []
    for iname in all_dom_inames_ordered:
        if iname not in dom_before_constraint_set.get_var_names(isl.dim_type.out):
            inames_before_unused.append(iname + "p")
    inames_after_unused = []
    for iname in all_dom_inames_ordered:
        if iname not in dom_after_constraint_set.get_var_names(isl.dim_type.out):
            inames_after_unused.append(iname)
    # TODO figure out before/after notation

    # initialize constraints to False (0 == 1);
    # this disappears as soon as the first disjunct is unioned in
    all_constraints_set = islvars[0].eq_set(islvars[0] + 1)
    space = all_constraints_set.space
    from pymbolic import parse
    from loopy.symbolic import aff_from_expr

    # TODO: plain str.split is fragile (e.g., an iname containing "or"
    # would be split); something more robust than string meddling is needed
    or_constraint_strs = constraint_str.split("or")

    def _quant(s):
        # parenthesize a sub-expression
        return "(" + s + ")"

    def _diff(s0, s1):
        # build the difference expression "(s0)-(s1)"
        return _quant(s0) + "-" + _quant(s1)

    for or_constraint_str in or_constraint_strs:
        and_constraint_strs = or_constraint_str.split("and")
        conj_constraint = isl.BasicSet.universe(space)  # init conjunction to True
        for cons_str in and_constraint_strs:
            # isl inequalities are of the form "aff >= 0", so each relation
            # below is rewritten as a (possibly shifted) difference
            if "<=" in cons_str:
                lhs, rhs = cons_str.split("<=")
                conj_constraint = conj_constraint.add_constraint(
                    isl.Constraint.inequality_from_aff(
                        aff_from_expr(space, parse(_diff(rhs, lhs)))))
            elif ">=" in cons_str:
                lhs, rhs = cons_str.split(">=")
                conj_constraint = conj_constraint.add_constraint(
                    isl.Constraint.inequality_from_aff(
                        aff_from_expr(space, parse(_diff(lhs, rhs)))))
            elif "<" in cons_str:
                lhs, rhs = cons_str.split("<")
                conj_constraint = conj_constraint.add_constraint(
                    isl.Constraint.inequality_from_aff(
                        aff_from_expr(space, parse(_diff(rhs, lhs) + "- 1"))))
            elif ">" in cons_str:
                lhs, rhs = cons_str.split(">")
                conj_constraint = conj_constraint.add_constraint(
                    isl.Constraint.inequality_from_aff(
                        aff_from_expr(space, parse(_diff(lhs, rhs) + "- 1"))))
            elif "=" in cons_str:
                lhs, rhs = cons_str.split("=")
                conj_constraint = conj_constraint.add_constraint(
                    isl.Constraint.equality_from_aff(
                        aff_from_expr(space, parse(_diff(lhs, rhs)))))
            else:
                # was a bare `1/0` placeholder crash; raise a real error
                raise ValueError(
                    "create_arbitrary_dependency_constraint: "
                    "unrecognized constraint: %s" % cons_str)
        all_constraints_set = all_constraints_set | conj_constraint

    # set unused vars == unused dummy param
    for iname in inames_before_unused+inames_after_unused:
        all_constraints_set = all_constraints_set & islvars[iname].eq_set(
            islvars[unused_param_name])

    # pin the statement dims to the given statement ids
    # (the "p"-marked dim is the 'before' statement)
    all_constraints_set = (
        all_constraints_set & islvars[statement_var_name_prime].eq_set(
            islvars[0]+sid_before)
        )
    all_constraints_set = (
        all_constraints_set & islvars[statement_var_name].eq_set(
            islvars[0]+sid_after)
        )

    # convert constraint set to map
    all_constraints_map = _convert_constraint_set_to_map(
        all_constraints_set, len(all_dom_inames_ordered) + 1)  # +1 for statement var

    # now apply domain sets to constraint variables

    # add statement variable to doms to enable intersection
    range_to_intersect = add_dims_to_isl_set(
        dom_after_constraint_set, isl.dim_type.out,
        [statement_var_name], statement_var_pose)
    domain_constraint_set = create_new_isl_set_with_primes(
        dom_before_constraint_set,
        marker="p")  # TODO figure out before/after notation
    domain_to_intersect = add_dims_to_isl_set(
        domain_constraint_set, isl.dim_type.out,
        [statement_var_name_prime], statement_var_pose)

    # insert inames missing from doms to enable intersection
    domain_to_intersect = reorder_dims_by_name(
        domain_to_intersect, isl.dim_type.out,
        append_marker_to_strings(  # TODO figure out before/after notation
            [statement_var_name] + all_dom_inames_ordered, "p"),
        add_missing=True)
    range_to_intersect = reorder_dims_by_name(
        range_to_intersect,
        isl.dim_type.out,
        [statement_var_name] + all_dom_inames_ordered,
        add_missing=True)

    # intersect doms
    map_with_loop_domain_constraints = all_constraints_map.intersect_domain(
        domain_to_intersect).intersect_range(range_to_intersect)

    return map_with_loop_domain_constraints
def create_dependencies_from_legacy_knl(knl):
    """Return a :class:`set` of :class:`StatementPairDependencySet` instances
    created for a :class:`loopy.LoopKernel` containing legacy dependencies.

    Create the new dependencies according to the following rules:

    (1) If a dependency exists between ``insn0`` and ``insn1``, create the
    dependency ``SAME(SNC)`` where ``SNC`` is the set of non-concurrent inames
    used by both ``insn0`` and ``insn1``, and ``SAME`` is the relationship
    specified by the ``SAME`` attribute of :class:`DependencyType`.

    (2) For each subset of non-concurrent inames used by any instruction,
    find the set of all instructions using those inames, create a directed
    graph with these instructions as nodes and edges representing a
    'happens before' relationship specified by each dependency, find the
    sources and sinks within this graph, and connect each sink to each
    source (sink happens before source) with a ``PRIOR(SNC)`` dependency,
    where ``PRIOR`` is the relationship specified by the ``PRIOR`` attribute
    of :class:`DependencyType`.
    """
    # Introduce SAME dep for set of shared, non-concurrent inames

    from schedule_checker.sched_check_utils import (
        get_concurrent_inames,
        get_all_nonconcurrent_insn_iname_subsets,
        get_sched_item_ids_within_inames,
    )
    from schedule_checker.schedule import LexScheduleStatement
    dt = DependencyType
    conc_inames, non_conc_inames = get_concurrent_inames(knl)
    statement_dep_sets = []
    for insn_after in knl.instructions:
        for insn_before_id in insn_after.depends_on:
            insn_before = knl.id_to_insn[insn_before_id]
            insn_before_inames = insn_before.within_inames
            insn_after_inames = insn_after.within_inames
            shared_inames = insn_before_inames & insn_after_inames
            shared_non_conc_inames = shared_inames & non_conc_inames

            statement_dep_sets.append(
                StatementPairDependencySet(
                    LexScheduleStatement(
                        insn_id=insn_before.id,
                        within_inames=insn_before_inames),
                    LexScheduleStatement(
                        insn_id=insn_after.id,
                        within_inames=insn_after_inames),
                    {dt.SAME: shared_non_conc_inames},
                    knl.get_inames_domain(insn_before_inames),
                    knl.get_inames_domain(insn_after_inames),
                    ))

    # loop-carried deps ------------------------------------------

    # Go through insns and get all unique insn.depends_on iname sets
    non_conc_iname_subsets = get_all_nonconcurrent_insn_iname_subsets(
        knl, exclude_empty=True, non_conc_inames=non_conc_inames)

    # For each set of insns within a given iname set, find sources and sinks.
    # Then make PRIOR dep from all sinks to all sources at previous iterations
    for iname_subset in non_conc_iname_subsets:
        # find items within this iname set
        sched_item_ids = get_sched_item_ids_within_inames(knl, iname_subset)

        # find sources and sinks
        sources, sinks = get_dependency_sources_and_sinks(knl, sched_item_ids)

        # create prior deps

        # in future, consider inserting single no-op source and sink
        for source_id in sources:
            for sink_id in sinks:
                sink_insn_inames = knl.id_to_insn[sink_id].within_inames
                source_insn_inames = knl.id_to_insn[source_id].within_inames
                shared_inames = sink_insn_inames & source_insn_inames
                shared_non_conc_inames = shared_inames & non_conc_inames

                statement_dep_sets.append(
                    StatementPairDependencySet(
                        LexScheduleStatement(
                            insn_id=sink_id,
                            within_inames=sink_insn_inames),
                        LexScheduleStatement(
                            insn_id=source_id,
                            within_inames=source_insn_inames),
                        {dt.PRIOR: shared_non_conc_inames},
                        knl.get_inames_domain(sink_insn_inames),
                        knl.get_inames_domain(source_insn_inames),
                        ))

    return set(statement_dep_sets)
def get_dependency_sources_and_sinks(knl, sched_item_ids):
    """Implicitly create a directed graph with the schedule items specified
    by ``sched_item_ids`` as nodes, and with edges representing a
    'happens before' relationship specified by each legacy dependency between
    two instructions. Return the sources and sinks within this graph.

    :arg knl: A :class:`loopy.LoopKernel` (only ``knl.id_to_insn`` and each
        instruction's ``depends_on`` attribute are used).

    :arg sched_item_ids: A :class:`set` of :class:`str` representing loopy
        instruction ids. (Note: must be a set, not a list; the sink
        computation below uses set subtraction.)

    :returns: Two instances of :class:`set` of :class:`str` instruction ids
        representing the sources and sinks in the dependency graph.
    """
    sources = set()
    dependees = set()  # all dependees (within sched_item_ids)
    for item_id in sched_item_ids:
        # find the deps within sched_item_ids (deps outside the subset
        # are deliberately ignored)
        deps = knl.id_to_insn[item_id].depends_on & sched_item_ids
        if deps:
            # add deps to dependees
            dependees.update(deps)
        else:  # has no deps (within sched_item_ids), this is a source
            sources.add(item_id)

    # sinks are items on which no other item (within the subset) depends
    sinks = sched_item_ids - dependees

    return sources, sinks


class DependencyInfo(object):
    """Bundle a dependency with its domains, its isl constraint map, and
    the result of the dependency-graph edge test.

    Attributes (set positionally by ``__init__``):

    - ``statement_pair_dep_set``: the :class:`StatementPairDependencySet`
      this info describes
    - ``dom_before`` / ``dom_after``: iname domains of the before/after
      statements
    - ``dep_constraint_map``: isl map expressing the dependency constraint
    - ``is_edge_in_dep_graph``: bool, True iff { dep & SAME } != empty
    """
    # TODO rename
    # TODO use Record?
    def __init__(
            self,
            statement_pair_dep_set,
            dom_before,
            dom_after,
            dep_constraint_map,
            is_edge_in_dep_graph,  # { dep & SAME } != empty
            ):
        self.statement_pair_dep_set = statement_pair_dep_set
        self.dom_before = dom_before
        self.dom_after = dom_after
        self.dep_constraint_map = dep_constraint_map
        self.is_edge_in_dep_graph = is_edge_in_dep_graph
+ loop_priority, + knl, # TODO avoid passing this in + ): + # TODO document + + dt = DependencyType + + # create map from loopy insn ids to ints + lp_insn_id_to_lex_sched_id = {} # TODO + next_sid = 0 + from loopy.schedule import Barrier, RunInstruction + for sched_item in schedule_items: + if isinstance(sched_item, (RunInstruction, Barrier)): + from schedule_checker.sched_check_utils import ( + _get_insn_id_from_sched_item, + ) + lp_insn_id = _get_insn_id_from_sched_item(sched_item) + lp_insn_id_to_lex_sched_id[lp_insn_id] = next_sid + next_sid += 1 + elif isinstance(sched_item, str): + # a string was passed, assume it's the insn_id + lp_insn_id_to_lex_sched_id[sched_item] = next_sid + next_sid += 1 + + from schedule_checker.sched_check_utils import ( + get_concurrent_inames, + ) + conc_inames, non_conc_inames = get_concurrent_inames(knl) + + dep_info_list = [] + for statement_pair_dep_set in deps_and_domains: + + dep_constraint_map = create_dependency_constraint( + statement_pair_dep_set, + loop_priority, + lp_insn_id_to_lex_sched_id, + "unused", # TODO shouldn't be necessary anymore + "statement", + ) + + # create "same" dep for these two insns + s_before = statement_pair_dep_set.statement_before + s_after = statement_pair_dep_set.statement_after + dom_before = statement_pair_dep_set.dom_before + dom_after = statement_pair_dep_set.dom_after + shared_nc_inames = ( + s_before.within_inames & s_after.within_inames & non_conc_inames) + same_dep_set = StatementPairDependencySet( + s_before, + s_after, + {dt.SAME: shared_nc_inames}, + dom_before, + dom_after, + ) + same_dep_constraint_map = create_dependency_constraint( + same_dep_set, + loop_priority, + lp_insn_id_to_lex_sched_id, + "unused", # TODO shouldn't be necessary + "statement", + ) + + # see whether we should create an edge in our statement dep graph + intersect_dep_and_same = same_dep_constraint_map & dep_constraint_map + intersect_not_empty = not bool(intersect_dep_and_same.is_empty()) + + # create a map 
representing constraints from the dependency, + # maps statement instance to all statement instances that must occur later + # TODO instead of tuple, store all this in a class + dep_info_list.append( + DependencyInfo( + statement_pair_dep_set, + dom_before, + dom_after, + dep_constraint_map, + intersect_not_empty, + ) + ) + print("") + + return dep_info_list diff --git a/loopy/schedule/schedule_checker/example_dependency_checking.py b/loopy/schedule/schedule_checker/example_dependency_checking.py new file mode 100644 index 0000000000000000000000000000000000000000..54ab553dbb0a499dbe9fc600905f743417c885cb --- /dev/null +++ b/loopy/schedule/schedule_checker/example_dependency_checking.py @@ -0,0 +1,189 @@ +import loopy as lp +from schedule_checker.dependency import ( # noqa + StatementPairDependencySet, + DependencyType as dt, + create_dependency_constraint, +) +from schedule_checker.lexicographic_order_map import ( + create_lex_order_map, + get_statement_ordering_map, +) +from schedule_checker.sched_check_utils import ( + prettier_map_string as pmap, + append_apostrophes, + create_explicit_map_from_tuples, + get_isl_space, +) +from schedule_checker.schedule import LexScheduleStatement + + +# make example kernel +knl = lp.make_kernel( + "{[i,j]: 0<=i,j<2}", + [ + "a[i,j] = b[i,j] {id=0}", + "a[i,j] = a[i,j] + 1 {id=1,dep=0}", + ], + name="example", + ) +knl = lp.tag_inames(knl, {"i": "l.0"}) +print("Kernel:") +print(knl) + +inames = ['i', 'j'] +statement_var = 's' +unused_param_name = 'unused' + +# example sched: +print("-"*80) + +# i is parallel, suppose we want to enforce the following: +# for a given i, statement 0 happens before statement 1 + +params_sched = ['p0', 'p1', unused_param_name] +in_names_sched = [statement_var]+inames +out_names_sched = ['l0', 'l1'] +sched_space = get_isl_space(params_sched, in_names_sched, out_names_sched) + +example_sched_valid = create_explicit_map_from_tuples( + [ + ((0, 0, 0), (0, 0)), + ((0, 1, 0), (0, 0)), + ((1, 0, 0), (0, 
1)), + ((1, 1, 0), (0, 1)), + ((0, 0, 1), (1, 0)), + ((0, 1, 1), (1, 0)), + ((1, 0, 1), (1, 1)), + ((1, 1, 1), (1, 1)), + ], + sched_space, + ) +print("example sched (valid):") +print(pmap(example_sched_valid)) + +example_sched_invalid = create_explicit_map_from_tuples( + [ + ((0, 0, 0), (0, 0)), + ((0, 1, 0), (1, 1)), # these two are out of order, violation + ((1, 0, 0), (0, 1)), + ((1, 1, 0), (0, 1)), + ((0, 0, 1), (1, 0)), + ((0, 1, 1), (1, 0)), + ((1, 0, 1), (1, 1)), + ((1, 1, 1), (0, 0)), # these two are out of order, violation + ], + sched_space, + ) +print("example sched (invalid):") +print(pmap(example_sched_invalid)) + +# Lexicographic order map- map each tuple to all tuples occuring later +print("-"*80) +n_dims = 2 +lex_order_map = create_lex_order_map(n_dims) +print("lexicographic order map:") +print(pmap(lex_order_map)) + +# Statement instance ordering (valid sched) +print("-"*80) +SIO_valid = get_statement_ordering_map( + example_sched_valid, lex_order_map) +print("statement instance ordering (valid_sched):") +print(pmap(SIO_valid)) + +# Statement instance ordering (invalid sched) +print("-"*80) +SIO_invalid = get_statement_ordering_map( + example_sched_invalid, lex_order_map) +print("statement instance ordering (invalid_sched):") +print(pmap(SIO_invalid)) + +# Dependencies and constraints: +print("-"*80) + +# make some dependencies manually: + +s0 = LexScheduleStatement(insn_id="0", within_inames={"i", "j"}) +s1 = LexScheduleStatement(insn_id="1", within_inames={"i", "j"}) +insnid_to_int_sid = {"0": 0, "1": 1} + +dom_before = knl.get_inames_domain(s0.within_inames) +dom_after = knl.get_inames_domain(s1.within_inames) + +statement_pair_dep_set = StatementPairDependencySet( + s0, s1, {dt.SAME: ["i", "j"]}, dom_before, dom_after) +# SAME({i,j}) means: +# insn0{i,j} happens before insn1{i',j'} iff i = i' and j = j' + +print("Statement pair dependency set:") +print(statement_pair_dep_set) + +loop_priority = None +constraint_map = 
create_dependency_constraint( + statement_pair_dep_set, + loop_priority, + insnid_to_int_sid, + unused_param_name, + statement_var, + #all_dom_inames_ordered=inames, # not necessary since algin spaces below + ) +print("constraint map (before aligning space):") +print(pmap(constraint_map)) + +assert SIO_valid.space == SIO_invalid.space + +# align constraint map spaces to match sio so we can compare them + +print("constraint map space (before aligning):") +print(constraint_map.space) + +# align params +aligned_constraint_map = constraint_map.align_params(SIO_valid.space) + +# align in_ dims +import islpy as isl +from schedule_checker.sched_check_utils import ( + reorder_dims_by_name, +) +SIO_valid_in_names = SIO_valid.space.get_var_names(isl.dim_type.in_) +aligned_constraint_map = reorder_dims_by_name( + aligned_constraint_map, + isl.dim_type.in_, + SIO_valid_in_names, + add_missing=False, + new_names_are_permutation_only=True, + ) + +# align out dims +aligned_constraint_map = reorder_dims_by_name( + aligned_constraint_map, + isl.dim_type.out, + append_apostrophes(SIO_valid_in_names), + # TODO SIO out names are only pretending to have apostrophes; confusing + add_missing=False, + new_names_are_permutation_only=True, + ) + +assert aligned_constraint_map.space == SIO_valid.space +assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.in_) + == SIO_valid.space.get_var_names(isl.dim_type.in_)) +assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.out) + == append_apostrophes(SIO_valid.space.get_var_names(isl.dim_type.out))) +assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.param) + == SIO_valid.space.get_var_names(isl.dim_type.param)) + +print("constraint map space (after aligning):") +print(aligned_constraint_map.space) +print("constraint map (after aligning space):") +print(pmap(aligned_constraint_map)) +print("SIO space:") +print(SIO_valid.space) + +print("is valid sched valid?") 
+print(aligned_constraint_map.is_subset(SIO_valid)) + +print("is invalid sched valid?") +print(aligned_constraint_map.is_subset(SIO_invalid)) diff --git a/loopy/schedule/schedule_checker/example_lex_map_creation.py b/loopy/schedule/schedule_checker/example_lex_map_creation.py new file mode 100644 index 0000000000000000000000000000000000000000..83ff538d32deb5128ae74af765c1f94d0db63c40 --- /dev/null +++ b/loopy/schedule/schedule_checker/example_lex_map_creation.py @@ -0,0 +1,43 @@ +from schedule_checker.lexicographic_order_map import ( + get_statement_ordering_map, + create_lex_order_map, +) +from schedule_checker.sched_check_utils import ( + create_explicit_map_from_tuples, + get_isl_space, + prettier_map_string as pmap, +) + +# Lexicographic order map- map each tuple to all tuples occuring later + +n_dims = 2 +lex_order_map = create_lex_order_map(n_dims) +print("lexicographic order map:") +print(pmap(lex_order_map)) + +# Example *explicit* schedule (map statement instances to lex time) + +param_names_sched = [] +in_names_sched = ["s"] +out_names_sched = ["i", "j"] +sched_space = get_isl_space(param_names_sched, in_names_sched, out_names_sched) +sched_explicit = create_explicit_map_from_tuples( + [ + ((0,), (0, 0)), + ((1,), (0, 1)), + ((2,), (1, 0)), + ((3,), (1, 1)), + ], + sched_space, + ) +print("example explicit sched:") +print(pmap(sched_explicit)) + +# Statement instance ordering: +# map each statement instance to all statement instances that occur later +# S -> L -> S^-1 + +sio = get_statement_ordering_map( + sched_explicit, lex_order_map) +print("Statement instance ordering:") +print(pmap(sio)) diff --git a/loopy/schedule/schedule_checker/example_pairwise_schedule_validity.py b/loopy/schedule/schedule_checker/example_pairwise_schedule_validity.py new file mode 100644 index 0000000000000000000000000000000000000000..542f6ee6f24738fda8e304655a7670680ad717d4 --- /dev/null +++ b/loopy/schedule/schedule_checker/example_pairwise_schedule_validity.py @@ -0,0 +1,349 
@@ +import loopy as lp +import numpy as np +from schedule_checker import ( + get_statement_pair_dependency_sets_from_legacy_knl, + check_schedule_validity, +) +from schedule_checker.sched_check_utils import ( + create_graph_from_pairs, +) +from schedule_checker.dependency import ( + get_dependency_maps, +) +from loopy.kernel import KernelState +from loopy import ( + preprocess_kernel, + get_one_scheduled_kernel, +) + +# Choose kernel ---------------------------------------------------------- + + +knl_choice = "example" +#knl_choice = "unused_inames" +#knl_choice = "matmul" +#knl_choice = "scan" +#knl_choice = "dependent_domain" +#knl_choice = "stroud_bernstein_orig" # TODO invalid sched? +#knl_choice = "ilp_kernel" +#knl_choice = "add_barrier" +#knl_choice = "nop" +#knl_choice = "nest_multi_dom" +#knl_choice = "loop_carried_deps" + +if knl_choice == "example": + knl = lp.make_kernel( + [ + "{[i,ii]: 0<=itemp = b[i,k] {id=insn_a} + end + for j + a[i,j] = temp + 1 {id=insn_b,dep=insn_a} + c[i,j] = d[i,j] {id=insn_c} + end + end + for t + e[t] = f[t] {id=insn_d} + end + """, + name="example", + assumptions="pi,pj,pk,pt >= 1", + lang_version=(2018, 2) + ) + knl = lp.add_and_infer_dtypes( + knl, + {"b": np.float32, "d": np.float32, "f": np.float32}) + #knl = lp.tag_inames(knl, {"i": "l.0"}) + #knl = lp.prioritize_loops(knl, "i,k,j") + knl = lp.prioritize_loops(knl, "i,k") + knl = lp.prioritize_loops(knl, "i,j") +if knl_choice == "unused_inames": + knl = lp.make_kernel( + [ + "{[i,ii]: 0<=itemp = b[i,k] {id=insn_a} + end + for j + a[i,j] = temp + 1 {id=insn_b,dep=insn_a} + end + end + """, + name="unused_inames", + assumptions="pi,pj,pk >= 1", + lang_version=(2018, 2) + ) + knl = lp.add_and_infer_dtypes( + knl, + {"b": np.float32}) + #knl = lp.tag_inames(knl, {"i": "l.0"}) + #knl = lp.prioritize_loops(knl, "i,k,j") + knl = lp.prioritize_loops(knl, "i,k") + knl = lp.prioritize_loops(knl, "i,j") +elif knl_choice == "matmul": + bsize = 16 + knl = lp.make_kernel( + 
"{[i,k,j]: 0<=i {[i,j]: 0<=i {[i]: 0<=i xi = qpts[1, i2] + <> s = 1-xi + <> r = xi/s + <> aind = 0 {id=aind_init} + for alpha1 + <> w = s**(deg-alpha1) {id=init_w} + for alpha2 + tmp[el,alpha1,i2] = tmp[el,alpha1,i2] + w * coeffs[aind] \ + {id=write_tmp,dep=init_w:aind_init} + w = w * r * ( deg - alpha1 - alpha2 ) / (1 + alpha2) \ + {id=update_w,dep=init_w:write_tmp} + aind = aind + 1 \ + {id=aind_incr,dep=aind_init:write_tmp:update_w} + end + end + end + """, + [lp.GlobalArg("coeffs", None, shape=None), "..."], + name="stroud_bernstein_orig", assumptions="deg>=0 and nels>=1") + knl = lp.add_and_infer_dtypes(knl, + dict(coeffs=np.float32, qpts=np.int32)) + knl = lp.fix_parameters(knl, nqp1d=7, deg=4) + knl = lp.split_iname(knl, "el", 16, inner_tag="l.0") + knl = lp.split_iname(knl, "el_outer", 2, outer_tag="g.0", + inner_tag="ilp", slabs=(0, 1)) + knl = lp.tag_inames(knl, dict(i2="l.1", alpha1="unr", alpha2="unr")) + # Must declare coeffs to have "no" shape, to keep loopy + # from trying to figure it out the shape automatically. +elif knl_choice == "ilp_kernel": + knl = lp.make_kernel( + "{[i,j,ilp_iname]: 0 <= i,j < n and 0 <= ilp_iname < 4}", + """ + for i + for j + for ilp_iname + tmp[i,j,ilp_iname] = 3.14 + end + end + end + """, + name="ilp_kernel", + assumptions="n>=1 and n mod 4 = 0", + ) + # TODO why is conditional on ilp_name? + knl = lp.tag_inames(knl, {"j": "l.0", "ilp_iname": "ilp"}) + #knl = lp.prioritize_loops(knl, "i_outer_outer,i_outer_inner,i_inner,a") +if knl_choice == "add_barrier": + np.random.seed(17) + #a = np.random.randn(16) + cnst = np.random.randn(16) + knl = lp.make_kernel( + "{[i, ii]: 0<=i, ii c_end = 2 + for c + ... 
nop + end + end + """, + "...", + seq_dependencies=True) + knl = lp.fix_parameters(knl, dim=3) +if knl_choice == "nest_multi_dom": + #"{[i,j,k]: 0<=i,j,kacc = 0 {id=insn0} + for j + for k + acc = acc + j + k {id=insn1,dep=insn0} + end + end + end + end + """, + name="nest_multi_dom", + #assumptions="n >= 1", + assumptions="ni,nj,nk,nx >= 1", + lang_version=(2018, 2) + ) + """ + <>foo = 0 {id=insn0} + for i + <>acc = 0 {id=insn1} + for j + for k + acc = acc + j + k {id=insn2,dep=insn1} + end + end + foo = foo + acc {id=insn3,dep=insn2} + end + <>bar = foo {id=insn4,dep=insn3} + """ + knl = lp.prioritize_loops(knl, "x,xx,i") + knl = lp.prioritize_loops(knl, "i,j") + knl = lp.prioritize_loops(knl, "j,k") + +if knl_choice == "loop_carried_deps": + knl = lp.make_kernel( + "{[i]: 0<=iacc0 = 0 {id=insn0} + for i + acc0 = acc0 + i {id=insn1,dep=insn0} + <>acc2 = acc0 + i {id=insn2,dep=insn1} + <>acc3 = acc2 + i {id=insn3,dep=insn2} + <>acc4 = acc0 + i {id=insn4,dep=insn1} + end + """, + name="loop_carried_deps", + assumptions="n >= 1", + lang_version=(2018, 2) + ) + +unprocessed_knl = knl.copy() + +legacy_deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl( + unprocessed_knl) + +# get a schedule to check +if knl.state < KernelState.PREPROCESSED: + knl = preprocess_kernel(knl) +knl = get_one_scheduled_kernel(knl) +print("kernel schedueld") +schedule_items = knl.schedule +print("checking validity") + +sched_is_valid = check_schedule_validity( + unprocessed_knl, legacy_deps_and_domains, schedule_items, verbose=True) + +""" +legacy_deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl(knl) + +# get a schedule to check +from loopy import get_one_scheduled_kernel +scheduled_knl = get_one_scheduled_kernel(knl) +schedule_items = scheduled_knl.schedule + +sched_is_valid = check_schedule_validity( + knl, legacy_deps_and_domains, schedule_items, verbose=True) +""" + +print("is sched valid? 
constraint map subset of SIO?") +print(sched_is_valid) + + +print("="*80) +print("testing dep sort") +print("="*80) + +# create maps representing legacy deps +# (includes bool representing result of test for dep graph edge) +legacy_dep_info_list = get_dependency_maps( + legacy_deps_and_domains, + schedule_items, + knl.loop_priority, + knl, + ) + +# get dep graph edges +dep_graph_pairs = [ + ( + dep.statement_pair_dep_set.statement_before.insn_id, + dep.statement_pair_dep_set.statement_after.insn_id + ) + for dep in legacy_dep_info_list if dep.is_edge_in_dep_graph] + +# create dep graph from edges +dep_graph = create_graph_from_pairs(dep_graph_pairs) + +print("dep_graph:") +for k, v in dep_graph.items(): + print("%s: %s" % (k, v)) diff --git a/loopy/schedule/schedule_checker/example_wave_equation.py b/loopy/schedule/schedule_checker/example_wave_equation.py new file mode 100644 index 0000000000000000000000000000000000000000..6afa3044b1fcbc23240eb87bdfce633f34f37dc1 --- /dev/null +++ b/loopy/schedule/schedule_checker/example_wave_equation.py @@ -0,0 +1,650 @@ +import loopy as lp +from loopy import generate_code_v2 +from loopy import get_one_scheduled_kernel +from loopy.kernel import KernelState +from loopy import preprocess_kernel +import numpy as np +import islpy as isl +#from loopy.kernel_stat_collector import KernelStatCollector +#from loopy.kernel_stat_collector import KernelStatOptions as kso # noqa +from schedule_checker import check_schedule_validity +from schedule_checker.sched_check_utils import ( + prettier_map_string, + reorder_dims_by_name, + append_apostrophes, + append_marker_to_isl_map_var_names, +) +from schedule_checker.dependency import ( + create_arbitrary_dependency_constraint, +) +from dependency import _create_5pt_stencil_dependency_constraint +from schedule_checker.schedule import LexSchedule +from schedule_checker.lexicographic_order_map import ( + get_statement_ordering_map, +) + +# Make kernel 
---------------------------------------------------------- + +# u[x,t+1] = 2*u[x,t] - u[x,t-1] + c*(dt/dx)**2*(u[x+1,t] - 2*u[x,t] + u[x-1,t]) +# mine, works: +# "{[x,t]: 1<=x {[ix, it]: 1<=ix {[ix, it]: 1<=ix lex time):") + #print(sched_map_symbolic.space) + #print("-"*80) + +# }}} + +# get map representing lexicographic ordering +lex_order_map_symbolic = sched.get_lex_order_map_for_symbolic_sched() + +# {{{ verbose + +""" +if verbose: + print("lex order map symbolic:") + print(prettier_map_string(lex_order_map_symbolic)) + print("space (lex time -> lex time):") + print(lex_order_map_symbolic.space) + print("-"*80) +""" + +# }}} + +# create statement instance ordering, +# maps each statement instance to all statement instances occuring later +sio = get_statement_ordering_map( + sched_map_symbolic_before, + sched_map_symbolic_after, + lex_order_map_symbolic, + before_marker="p") + +# {{{ verbose + +if verbose: + print("statement instance ordering:") + print(prettier_map_string(sio)) + print("SIO space (statement instances -> statement instances):") + print(sio.space) + print("-"*80) + +if verbose: + print("constraint map space (before aligning):") + print(constraint_map.space) + +# }}} + +# align constraint map spaces to match sio so we can compare them +# align params +aligned_constraint_map = constraint_map.align_params(sio.space) +#print(prettier_map_string(aligned_constraint_map)) + +# align in_ dims +sio_in_names = sio.space.get_var_names(isl.dim_type.in_) +aligned_constraint_map = reorder_dims_by_name( + aligned_constraint_map, + isl.dim_type.in_, + sio_in_names, + add_missing=False, + new_names_are_permutation_only=True, + ) + +# align out dims +sio_out_names = sio.space.get_var_names(isl.dim_type.out) +aligned_constraint_map = reorder_dims_by_name( + aligned_constraint_map, + isl.dim_type.out, + sio_out_names, + add_missing=False, + new_names_are_permutation_only=True, + ) + +# {{{ verbose + +if verbose: + print("constraint map space (after aligning):") + 
print(aligned_constraint_map.space) + print("constraint map:") + print(prettier_map_string(aligned_constraint_map)) + +# }}} + +assert aligned_constraint_map.space == sio.space +assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.in_) + == sio.space.get_var_names(isl.dim_type.in_)) +assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.out) + == sio.space.get_var_names(isl.dim_type.out)) +assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.param) + == sio.space.get_var_names(isl.dim_type.param)) + +sched_is_valid = aligned_constraint_map.is_subset(sio) + +if not sched_is_valid: + + # {{{ verbose + + if verbose: + print("================ constraint check failure =================") + print("constraint map not subset of SIO") + print("dependency:") + print(prettier_map_string(constraint_map)) + print("statement instance ordering:") + print(prettier_map_string(sio)) + print("constraint_map.gist(sio):") + print(aligned_constraint_map.gist(sio)) + print("sio.gist(constraint_map)") + print(sio.gist(aligned_constraint_map)) + print("loop priority known:") + print(preprocessed_knl.loop_priority) + """ + from schedule_checker.sched_check_utils import ( + get_concurrent_inames, + ) + conc_inames, non_conc_inames = get_concurrent_inames(scheduled_knl) + print("concurrent inames:", conc_inames) + print("sequential inames:", non_conc_inames) + print("constraint map space (stmt instances -> stmt instances):") + print(aligned_constraint_map.space) + print("SIO space (statement instances -> statement instances):") + print(sio.space) + print("constraint map:") + print(prettier_map_string(aligned_constraint_map)) + print("statement instance ordering:") + print(prettier_map_string(sio)) + print("{insn id -> sched sid int} dict:") + print(lp_insn_id_to_lex_sched_id) + """ + print("===========================================================") + + # }}} + +print("is sched valid? 
constraint map subset of SIO?") +print(sched_is_valid) + + +# ====================================================================== +# now do this with complicated mapping + + +# create mapping: +# old (wrong) +""" +m = isl.BasicMap( + "[nx,nt] -> {[ix, it] -> [tx, tt, tparity, itt, itx]: " + "16*(tx - tt + tparity) + itx - itt = ix - it and " + "16*(tx + tt) + itt + itx = ix + it and " + "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}") +m2 = isl.BasicMap( + "[nx,nt,unused] -> {[statement, ix, it] -> [statement'=statement, tx, tt, tparity, itt, itx]: " + "16*(tx - tt + tparity) + itx - itt = ix - it and " + "16*(tx + tt) + itt + itx = ix + it and " + "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}") +m2_prime = isl.BasicMap( + "[nx,nt,unused] -> {[statement, ix, it] -> [statement'=statement, tx', tt', tparity', itt', itx']: " + "16*(tx' - tt' + tparity') + itx' - itt' = ix - it and " + "16*(tx' + tt') + itt' + itx' = ix + it and " + "0<=tparity'<2 and 0 <= itx' - itt' < 16 and 0 <= itt'+itx' < 16}") +""" + +# new +m = isl.BasicMap( + "[nx,nt] -> {[ix, it] -> [tx, tt, tparity, itt, itx]: " + "16*(tx - tt) + itx - itt = ix - it and " + "16*(tx + tt + tparity) + itt + itx = ix + it and " + "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}") +m2 = isl.BasicMap( + "[nx,nt,unused] -> {[statement, ix, it] -> [statement'=statement, tx, tt, tparity, itt, itx]: " + "16*(tx - tt) + itx - itt = ix - it and " + "16*(tx + tt + tparity) + itt + itx = ix + it and " + "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}") +#m2_primes_after = isl.BasicMap( +# "[nx,nt,unused] -> {[statement, ix, it] -> [statement'=statement, tx', tt', tparity', itt', itx']: " +# "16*(tx' - tt') + itx' - itt' = ix - it and " +# "16*(tx' + tt' + tparity') + itt' + itx' = ix + it and " +# "0<=tparity'<2 and 0 <= itx' - itt' < 16 and 0 <= itt'+itx' < 16}") +m2_prime = isl.BasicMap( + "[nx,nt,unused] -> {[statement', ix', it'] -> [statement=statement', tx, 
tt, tparity, itt, itx]: " + "16*(tx - tt) + itx - itt = ix' - it' and " + "16*(tx + tt + tparity) + itt + itx = ix' + it' and " + "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}") + +# TODO note order must match statement_iname_premap_order + +print("maping:") +print(prettier_map_string(m2)) + +# new kernel +knl = lp.map_domain(ref_knl, m) +knl = lp.prioritize_loops(knl, "tt,tparity,tx,itt,itx") +print("code after mapping:") +print(generate_code_v2(knl).device_code()) +#1/0 + +print("constraint_map before apply_range:") +print(prettier_map_string(constraint_map)) +#mapped_constraint_map = constraint_map.apply_range(m2_prime) +mapped_constraint_map = constraint_map.apply_range(m2) +print("constraint_map after apply_range:") +print(prettier_map_string(mapped_constraint_map)) +#mapped_constraint_map = mapped_constraint_map.apply_domain(m2) +mapped_constraint_map = mapped_constraint_map.apply_domain(m2_prime) +# put primes on *before* names +mapped_constraint_map = append_marker_to_isl_map_var_names( + mapped_constraint_map, isl.dim_type.in_, marker="'") + +print("constraint_map after apply_domain:") +print(prettier_map_string(mapped_constraint_map)) + +statement_inames_mapped = set(["itx","itt","tt","tparity","tx"]) +sid_before = 0 +sid_after = 0 + +if knl.state < KernelState.PREPROCESSED: + preprocessed_knl = preprocess_kernel(knl) +else: + preprocessed_knl = knl +inames_domain_before_mapped = preprocessed_knl.get_inames_domain(statement_inames_mapped) +inames_domain_after_mapped = preprocessed_knl.get_inames_domain(statement_inames_mapped) +print("(mapped) inames_domain_before:", inames_domain_before_mapped) +print("(mapped) inames_domain_after:", inames_domain_after_mapped) + +# ============================================= + +verbose = False +verbose = True + +# get a schedule to check +if preprocessed_knl.schedule is None: + scheduled_knl = get_one_scheduled_kernel(preprocessed_knl) +else: + scheduled_knl = preprocessed_knl + +# {{{ verbose + +if 
verbose: + # Print kernel info ------------------------------------------------------ + print("="*80) + print("Kernel:") + print(scheduled_knl) + #print(generate_code_v2(scheduled_knl).device_code()) + print("="*80) + print("Iname tags: %s" % (scheduled_knl.iname_to_tags)) + print("="*80) + print("Loopy schedule:") + for sched_item in scheduled_knl.schedule: + print(sched_item) + #print("scheduled iname order:") + #print(sched_iname_order) + + print("="*80) + print("inames_domain_before_mapped:", inames_domain_before_mapped) + print("inames_domain_after_mapped:", inames_domain_after_mapped) + +# }}} + +# Create a mapping of {statement instance: lex point} +# including only instructions involved in this dependency +sched = LexSchedule( + scheduled_knl, + scheduled_knl.schedule, + str(sid_before), + str(sid_after) + ) + +# Get an isl map representing the LexSchedule; +# this requires the iname domains + +# get a mapping from lex schedule id to relevant inames domain +sid_to_dom = { + sid_before: inames_domain_before_mapped, + sid_after: inames_domain_after_mapped, + } + +#sched_map_symbolic = sched.create_symbolic_isl_map(sid_to_dom) +sched_map_symbolic_before, sched_map_symbolic_after = sched.create_symbolic_isl_maps( + inames_domain_before_mapped, inames_domain_after_mapped) + +# {{{ verbose + +if verbose: + print("sid_to_dom:\n", sid_to_dom) + print("LexSchedule after creating symbolic isl map:") + print(sched) + print("LexSched:") + print(prettier_map_string(sched_map_symbolic_before)) + print(prettier_map_string(sched_map_symbolic_after)) + #print("space (statement instances -> lex time):") + #print(sched_map_symbolic.space) + #print("-"*80) + +# }}} + +# get map representing lexicographic ordering +lex_order_map_symbolic = sched.get_lex_order_map_for_symbolic_sched() + +# {{{ verbose + +""" +if verbose: + print("lex order map symbolic:") + print(prettier_map_string(lex_order_map_symbolic)) + print("space (lex time -> lex time):") + 
print(lex_order_map_symbolic.space) + print("-"*80) +""" + +# }}} + +# create statement instance ordering, +# maps each statement instance to all statement instances occuring later +sio = get_statement_ordering_map( + sched_map_symbolic_before, + sched_map_symbolic_after, + lex_order_map_symbolic, + before_marker="'") + +# {{{ verbose + +if verbose: + print("statement instance ordering:") + print(prettier_map_string(sio)) + print("SIO space (statement instances -> statement instances):") + print(sio.space) + print("-"*80) + +if verbose: + print("constraint map space (before aligning):") + print(constraint_map.space) + +# }}} + +# align constraint map spaces to match sio so we can compare them +# align params +aligned_constraint_map = mapped_constraint_map.align_params(sio.space) +#print(prettier_map_string(aligned_constraint_map)) + +# align in_ dims +sio_in_names = sio.space.get_var_names(isl.dim_type.in_) +aligned_constraint_map = reorder_dims_by_name( + aligned_constraint_map, + isl.dim_type.in_, + sio_in_names, + add_missing=False, + new_names_are_permutation_only=True, + ) + +# align out dims +sio_out_names = sio.space.get_var_names(isl.dim_type.out) +aligned_constraint_map = reorder_dims_by_name( + aligned_constraint_map, + isl.dim_type.out, + sio_out_names, + add_missing=False, + new_names_are_permutation_only=True, +) + +# {{{ verbose + +if verbose: + print("constraint map space (after aligning):") + print(aligned_constraint_map.space) + print("constraint map:") + print(prettier_map_string(aligned_constraint_map)) + +# }}} + +assert aligned_constraint_map.space == sio.space +assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.in_) + == sio.space.get_var_names(isl.dim_type.in_)) +assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.out) + == sio.space.get_var_names(isl.dim_type.out)) +assert ( + aligned_constraint_map.space.get_var_names(isl.dim_type.param) + == sio.space.get_var_names(isl.dim_type.param)) + +sched_is_valid = 
aligned_constraint_map.is_subset(sio) + +if not sched_is_valid: + + # {{{ verbose + + if verbose: + print("================ constraint check failure =================") + print("constraint map not subset of SIO") + print("dependency:") + print(prettier_map_string(constraint_map)) + print("statement instance ordering:") + print(prettier_map_string(sio)) + print("constraint_map.gist(sio):") + print(aligned_constraint_map.gist(sio)) + print("sio.gist(constraint_map)") + print(sio.gist(aligned_constraint_map)) + print("loop priority known:") + print(preprocessed_knl.loop_priority) + """ + from schedule_checker.sched_check_utils import ( + get_concurrent_inames, + ) + conc_inames, non_conc_inames = get_concurrent_inames(scheduled_knl) + print("concurrent inames:", conc_inames) + print("sequential inames:", non_conc_inames) + print("constraint map space (stmt instances -> stmt instances):") + print(aligned_constraint_map.space) + print("SIO space (statement instances -> statement instances):") + print(sio.space) + print("constraint map:") + print(prettier_map_string(aligned_constraint_map)) + print("statement instance ordering:") + print(prettier_map_string(sio)) + print("{insn id -> sched sid int} dict:") + print(lp_insn_id_to_lex_sched_id) + """ + print("===========================================================") + + # }}} + +print("is sched valid? 
# --- loopy/schedule/schedule_checker/lexicographic_order_map.py -------------
# (new file in this patch; file-level ``import islpy as isl`` kept at the
# module header, outside this block)


def get_statement_ordering_map(
        sched_map_before, sched_map_after, lex_map, before_marker="'"):
    """Map each statement instance to all statement instances occurring later.

    :arg sched_map_before: an :class:`islpy.Map` taking each dependee
        statement instance to its point in the lexicographic ordering.

    :arg sched_map_after: an :class:`islpy.Map` taking each depender
        statement instance to its point in the lexicographic ordering.

    :arg lex_map: an :class:`islpy.Map` taking each point in lexicographic
        time to every later point, e.g.::

            {[i0', i1', ...] -> [i0, i1, ...] :
             i0' < i0 or (i0' = i0 and i1' < i1) or ...}

    :arg before_marker: a :class:`str` appended to every input-dim name of
        the result so 'before' instances are distinguishable from 'after'.

    :returns: the composition B -> L -> A^-1, where B is
        *sched_map_before*, L is *lex_map*, and A is *sched_map_after*:
        a map from each statement instance to all later instances.
    """
    # compose B -> L, then invert A on the range side
    ordering = sched_map_before.apply_range(lex_map)
    ordering = ordering.apply_range(sched_map_after.reverse())

    # append the marker to every 'in' dim name
    for pos in range(ordering.dim(isl.dim_type.in_)):
        marked = ordering.get_dim_name(isl.dim_type.in_, pos) + before_marker
        ordering = ordering.set_dim_name(isl.dim_type.in_, pos, marked)
    return ordering


def get_lex_order_constraint(islvars, before_names, after_names):
    """Return an :class:`islpy.Set` expressing 'happens before' in a
    lexicographic ordering.

    :arg islvars: a dict from variable names to :class:`PwAff` instances
        (as produced by `islpy.make_zero_and_vars`); the key '0' holds a
        zero constant. This dict fixes the space of the result.

    :arg before_names: variable names for the earlier lexicographic point.

    :arg after_names: variable names for the later lexicographic point.

    :returns: e.g. for ``before_names = [i0', i1', i2']`` and
        ``after_names = [i0, i1, i2]``::

            {[i0', i1', i2', i0, i1, i2] :
             i0' < i0 or (i0' = i0 and i1' < i1)
             or (i0' = i0 and i1' = i1 and i2' < i2)}
    """
    name_pairs = list(zip(before_names, after_names))

    first_before, first_after = name_pairs[0]
    constraint = islvars[first_before].lt_set(islvars[first_after])
    for pos, (b_name, a_name) in enumerate(name_pairs[1:], start=1):
        # strictly-less at dim `pos`, equal at every earlier dim
        term = islvars[b_name].lt_set(islvars[a_name])
        for eq_b, eq_a in name_pairs[:pos]:
            term = term & islvars[eq_b].eq_set(islvars[eq_a])
        constraint = constraint | term
    return constraint


def create_lex_order_map(
        n_dims,
        before_names=None,
        after_names=None,
        ):
    """Return an :class:`islpy.Map` taking each point in an *n_dims*-deep
    lexicographic ordering to every later point.

    :arg n_dims: number of dimensions in the lexicographic ordering.

    :arg before_names: dim names for the earlier point (default
        ``i0, i1, ...``).

    :arg after_names: dim names for the later point (default: before
        names with an underscore appended).
    """
    if before_names is None:
        before_names = ["i%s" % (i) for i in range(n_dims)]
    if after_names is None:
        from schedule_checker.sched_check_utils import (
            append_marker_to_strings,
        )
        after_names = append_marker_to_strings(before_names, marker="_")

    assert len(before_names) == len(after_names) == n_dims

    islvars = isl.make_zero_and_vars(before_names + after_names, [])
    constraint = get_lex_order_constraint(islvars, before_names, after_names)

    # build the constraint as a set, then shift the 'after' dims into the
    # output tuple to obtain a map
    lex_map = isl.Map.from_domain(constraint)
    return lex_map.move_dims(
        isl.dim_type.out, 0, isl.dim_type.in_,
        len(before_names), len(after_names))


# --- loopy/schedule/schedule_checker/sched_check_utils.py (beginning) -------
# (new file in this patch; file-level ``import islpy as isl`` kept at the
# module header, outside this block)

# TODO remove assertions once satisfied they are unnecessary
# TODO update all documentation/comments after apostrophe switched to
# *before* statement/inames


def prettier_map_string(isl_map):
    """Return ``str(isl_map)`` reformatted with line breaks for reading."""
    text = str(isl_map)
    for old, new in (("{ ", "{\n"), (" }", "\n}"), ("; ", ";\n")):
        text = text.replace(old, new)
    return text


def get_islvars_from_space(space):
    """Return zero/var :class:`PwAff`\\ s for the in/out vars and params
    of *space* (via `islpy.make_zero_and_vars`)."""
    param_names = space.get_var_names(isl.dim_type.param)
    in_names = space.get_var_names(isl.dim_type.in_)
    out_names = space.get_var_names(isl.dim_type.out)
    return isl.make_zero_and_vars(in_names + out_names, param_names)


def add_dims_to_isl_set(isl_set, dim_type, names, new_pose_start):
    """Insert ``len(names)`` new dims named *names* into *isl_set* at
    position *new_pose_start* of *dim_type*."""
    grown = isl_set.insert_dims(dim_type, new_pose_start, len(names))
    grown = grown.set_dim_name(dim_type, new_pose_start, names[0])
    for offset, name in enumerate(names[1:], start=1):
        grown = grown.set_dim_name(dim_type, new_pose_start + offset, name)
    return grown
def reorder_dims_by_name(
        isl_set, dim_type, desired_dims_ordered,
        add_missing=False, new_names_are_permutation_only=False):
    """Return an isl_set with the dimensions in the specified order.

    .. arg isl_set: A :class:`islpy.Set` whose dimensions are
        to be reordered.

    .. arg dim_type: A :class:`islpy.dim_type`, i.e., an :class:`int`,
        specifying the dimension to be reordered.

    .. arg desired_dims_ordered: A :class:`list` of :class:`str` elements
        representing the desired dimension order by dimension name.

    .. arg add_missing: A :class:`bool` specifying whether to insert
        dimensions (by name) found in `desired_dims_ordered` that are not
        present in `isl_set`.

    .. arg new_names_are_permutation_only: A :class:`bool` indicating that
        `desired_dims_ordered` contains exactly the same names as the
        specified dimensions in `isl_set`; if True and the two name sets
        do not match, an error is raised.

    .. return: An :class:`islpy.Set` matching `isl_set` with the
        dimension order matching `desired_dims_ordered`, optionally
        including additional dimensions present in `desired_dims_ordered`
        that are not present in `isl_set`.
    """
    assert set(isl_set.get_var_names(dim_type)).issubset(desired_dims_ordered)
    assert dim_type != isl.dim_type.param

    if new_names_are_permutation_only and (
            set(isl_set.get_var_names(dim_type))
            != set(desired_dims_ordered)):
        raise ValueError(
            "Var name sets must match with new_names_are_permutation_only=True. "
            "isl vars: %s, desired dims: %s"
            % (isl_set.get_var_names(dim_type), desired_dims_ordered))

    # isl can only move dims between *different* dim types, so park each
    # moved dim in the param dims temporarily
    other_dim_type = isl.dim_type.param
    other_dim_len = len(isl_set.get_var_names(other_dim_type))

    new_set = isl_set.copy()
    for desired_pose, name in enumerate(desired_dims_ordered):
        # if iname doesn't exist in set, add dim:
        if name not in new_set.get_var_names(dim_type):
            if add_missing:
                # insert missing dim in correct location
                new_set = new_set.insert_dims(
                    dim_type, desired_pose, 1
                    ).set_dim_name(
                    dim_type, desired_pose, name)
        else:  # iname exists in set
            current_pose = new_set.find_dim_by_name(dim_type, name)
            if current_pose != desired_pose:
                # move_dims(dst_type, dst_pose, src_type, src_pose, n)

                # first move to other dim type, then into the desired slot
                # TODO is this safe?
                new_set = new_set.move_dims(
                    other_dim_type, other_dim_len, dim_type, current_pose, 1)
                new_set = new_set.move_dims(
                    dim_type, desired_pose, other_dim_type, other_dim_len, 1)

    return new_set


def create_new_isl_set_with_primes(old_isl_set, marker="'"):
    """Return *old_isl_set* with *marker* appended to each
    ``dim_type.set`` dimension name."""
    # TODO this is just a special case of append_marker_to_isl_map_var_names
    new_set = old_isl_set.copy()
    for i in range(old_isl_set.n_dim()):
        new_set = new_set.set_dim_name(
            isl.dim_type.set, i, old_isl_set.get_dim_name(
                isl.dim_type.set, i)+marker)
    return new_set


def append_marker_to_isl_map_var_names(old_isl_map, dim_type, marker="'"):
    """Return *old_isl_map* with *marker* appended to each *dim_type*
    dimension name."""
    new_map = old_isl_map.copy()
    for i in range(len(old_isl_map.get_var_names(dim_type))):
        new_map = new_map.set_dim_name(dim_type, i, old_isl_map.get_dim_name(
            dim_type, i)+marker)
    return new_map


def make_islvars_with_marker(
        var_names_needing_marker, other_var_names, param_names, marker="'"):
    """Return a dict from variable and parameter names to :class:`PwAff`
    instances (as from `islpy.make_zero_and_vars`; the key '0' is also
    included and holds a zero constant), with *marker* appended to each
    name in *var_names_needing_marker*.
    """
    # DRY fix: the original carried a private copy of the append-marker loop
    marked = [name + marker for name in var_names_needing_marker]
    return isl.make_zero_and_vars(marked + other_var_names, param_names)


def append_marker_to_strings(strings, marker="'"):
    """Return a new list with *marker* appended to each element."""
    if not isinstance(strings, list):
        raise ValueError("append_marker_to_strings did not receive a list")
    return [s+marker for s in strings]


def append_apostrophes(strings):
    """Shorthand for ``append_marker_to_strings(strings, marker="'")``."""
    return append_marker_to_strings(strings, marker="'")


def _union_of_isl_sets_or_maps(set_list):
    """Union all isl sets (or maps) in the non-empty list *set_list*."""
    union = set_list[0]
    for s in set_list[1:]:
        union = union.union(s)
    return union


def list_var_names_in_isl_sets(isl_sets, set_dim=None):
    """Return a sorted list of all distinct var names of dim *set_dim*
    (default ``isl.dim_type.set``) appearing in *isl_sets*.

    Fixes: the default for *set_dim* used to be ``isl.dim_type.set``
    evaluated at import time; it is now resolved lazily (backward
    compatible).  The result is sorted so the order is deterministic.
    """
    if set_dim is None:
        set_dim = isl.dim_type.set
    inames = set()
    for isl_set in isl_sets:
        inames.update(isl_set.get_var_names(set_dim))
    return sorted(inames)


def create_symbolic_isl_map_from_tuples(
        tuple_pairs_with_domains,
        space,
        unused_param_name,
        statement_var_name,
        ):
    """Return an :class:`islpy.Map` on *space* mapping each input tuple
    to its output tuple, constrained by the accompanying domain.

    .. arg tuple_pairs_with_domains: A :class:`list` of elements of the
        form ``((tup_in, tup_out), domain)``; tuple entries are
        :class:`int` or :class:`str` values for the corresponding dims of
        *space*, and ``domain`` is an :class:`islpy.Set` bounding them.

    .. arg space: A :class:`islpy.Space` to be used to create the map.

    .. arg unused_param_name: A :class:`str` naming the dummy isl
        parameter assigned to any ``in_`` dim of *space* not present in a
        pair's domain (an iname unused by that statement instance).

    .. arg statement_var_name: A :class:`str` naming the isl variable
        holding the unique :class:`int` statement id.
    """
    # TODO clarify this with more comments
    # TODO allow None for domains

    dim_type = isl.dim_type

    space_out_names = space.get_var_names(dim_type.out)
    space_in_names = space.get_var_names(isl.dim_type.in_)

    islvars = get_islvars_from_space(space)

    # loop through pairs, build one map each, union at the end
    all_maps = []
    for (tup_in, tup_out), dom in tuple_pairs_with_domains:

        # initialize constraint with true
        constraint = islvars[0].eq_set(islvars[0])

        # pin 'in' dims to the tuple values (ints -> constants,
        # strings -> variable equalities)
        assert len(tup_in) == len(space_in_names)
        for dim_name, val_in in zip(space_in_names, tup_in):
            if isinstance(val_in, int):
                constraint = constraint \
                    & islvars[dim_name].eq_set(islvars[0]+val_in)
            else:
                constraint = constraint \
                    & islvars[dim_name].eq_set(islvars[val_in])

        # TODO we probably shouldn't rely on dom
        # here for determining where to set inames equal to dummy vars,
        # should instead determine before in LexSchedule and pass info in
        dom_var_names = dom.get_var_names(dim_type.set)
        # bugfix: was an `if not ...: assert False` with no message
        assert set(
            var for var in tup_out if not isinstance(var, int)
            ).issubset(set(dom_var_names)), (
            "non-integer output-tuple vars %s not all present in domain "
            "vars %s" % (tup_out, dom_var_names))
        unused_inames = set(space_in_names) \
            - set(dom_var_names) - set([statement_var_name])
        # TODO find another way to determine which inames should be unused
        # and make an assertion to double check this
        for unused_iname in unused_inames:
            constraint = constraint & islvars[unused_iname].eq_set(
                islvars[unused_param_name])

        # pin 'out' dims to the tuple values
        assert len(tup_out) == len(space_out_names)
        for dim_name, val_out in zip(space_out_names, tup_out):
            if isinstance(val_out, int):
                constraint = constraint \
                    & islvars[dim_name].eq_set(islvars[0]+val_out)
            else:
                constraint = constraint \
                    & islvars[dim_name].eq_set(islvars[val_out])

        # convert set to map by moving the out dims over
        map_from_set = isl.Map.from_domain(constraint)
        map_from_set = map_from_set.move_dims(
            dim_type.out, 0, dim_type.in_,
            len(space_in_names), len(space_out_names))

        assert space_in_names == map_from_set.get_var_names(
            isl.dim_type.in_)

        # if any dims in dom are missing from map_from_set we have a
        # problem (the add_missing pass asserts the name sets line up)
        dom_with_all_inames = reorder_dims_by_name(
            dom, isl.dim_type.set,
            space_in_names,
            add_missing=True,
            new_names_are_permutation_only=False,
            )

        # intersect domain with this map
        all_maps.append(
            map_from_set.intersect_domain(dom_with_all_inames))

    return _union_of_isl_sets_or_maps(all_maps)
def set_all_isl_space_names(
        isl_space, param_names=None, in_names=None, out_names=None):
    """Return a copy of *isl_space* with the specified dimension names.

    If no names are provided, use ``p0, p1, ...`` for parameters,
    ``i0, i1, ...`` for in_ dimensions, and ``o0, o1, ...`` for out
    dimensions.
    """
    new_space = isl_space.copy()
    dim_type = isl.dim_type
    # identical renaming recipe for each of the three dim kinds
    for dt, given_names, default_prefix in (
            (dim_type.param, param_names, "p"),
            (dim_type.in_, in_names, "i"),
            (dim_type.out, out_names, "o"),
            ):
        if given_names:
            names = list(given_names)
        else:
            names = [
                "%s%d" % (default_prefix, pos)
                for pos in range(len(isl_space.get_var_names(dt)))]
        for pos, name in enumerate(names):
            new_space = new_space.set_dim_name(dt, pos, name)
    return new_space


def get_isl_space(param_names, in_names, out_names):
    """Return an :class:`islpy.Space` with the specified dimension names."""
    raw_space = isl.Space.alloc(
        isl.DEFAULT_CONTEXT, len(param_names), len(in_names), len(out_names))
    return set_all_isl_space_names(
        raw_space, param_names=param_names,
        in_names=in_names, out_names=out_names)
def get_concurrent_inames(knl):
    """Partition the inames of *knl* into concurrently-tagged and
    non-concurrently-tagged sets.

    .. return: a tuple ``(conc_inames, non_conc_inames)`` of
        :class:`set`\\ s of :class:`str` iname names.
    """
    from loopy.kernel.data import ConcurrentTag
    conc_inames = set()
    non_conc_inames = set()

    for iname in knl.all_inames():
        if knl.iname_tags_of_type(iname, ConcurrentTag):
            conc_inames.add(iname)
        else:
            non_conc_inames.add(iname)

    # cleanup: previously returned `all_inames-conc_inames` and left the
    # accumulated non_conc_inames as a dead store (same value)
    return conc_inames, non_conc_inames


def _get_insn_id_from_sched_item(sched_item):
    """Return the instruction id for a schedule item (a Barrier stores it
    under ``originating_insn_id``)."""
    # TODO could use loopy's sched_item_to_insn_id()
    from loopy.schedule import Barrier
    if isinstance(sched_item, Barrier):
        return sched_item.originating_insn_id
    else:
        return sched_item.insn_id


# TODO for better performance, could combine these funcs so we don't
# loop over schedule more than once
def get_all_nonconcurrent_insn_iname_subsets(
        knl, exclude_empty=False, non_conc_inames=None):
    """Return a :class:`set` of every unique subset of non-concurrent
    inames used by an instruction in *knl*.

    .. arg knl: A :class:`loopy.LoopKernel`.

    .. arg exclude_empty: A :class:`bool` specifying whether to
        exclude the empty set.

    .. arg non_conc_inames: A :class:`set` of non-concurrent inames
        which may be provided if already known.
    """
    if non_conc_inames is None:
        _, non_conc_inames = get_concurrent_inames(knl)

    # within_inames is a frozenset, so each intersection is hashable
    iname_subsets = {
        insn.within_inames & non_conc_inames for insn in knl.instructions}

    if exclude_empty:
        iname_subsets.discard(frozenset())

    return iname_subsets


def get_sched_item_ids_within_inames(knl, inames):
    """Return the :class:`set` of ids of all instructions of *knl* nested
    within every iname in *inames*."""
    return set(
        insn.id for insn in knl.instructions
        if inames.issubset(insn.within_inames))


# TODO use yield to clean this up
# TODO use topological sort from loopy, then find longest path in dag
def _generate_orderings_starting_w_prefix(
        allowed_after_dict, orderings, required_length=None,
        start_prefix=(), return_first_found=False):
    """Recursively extend *start_prefix* with every allowed next item,
    accumulating complete orderings into the *orderings* set (mutated
    in place).

    allowed_after_dict: {str: set(str)}; start_prefix: tuple(str);
    orderings: set.
    """
    if start_prefix:
        next_items = allowed_after_dict[start_prefix[-1]]-set(start_prefix)
    else:
        next_items = allowed_after_dict.keys()

    if required_length:
        if len(start_prefix) == required_length:
            orderings.add(start_prefix)
            if return_first_found:
                return
    else:
        orderings.add(start_prefix)
        if return_first_found:
            return

    # return if no more items left
    if not next_items:
        return

    for next_item in next_items:
        new_prefix = start_prefix + (next_item,)
        _generate_orderings_starting_w_prefix(
            allowed_after_dict,
            orderings,
            required_length=required_length,
            start_prefix=new_prefix,
            return_first_found=return_first_found,
            )
        if return_first_found and orderings:
            return
    return


def get_orderings_of_length_n(
        allowed_after_dict, required_length, return_first_found=False):
    """Return all orderings found in the tree represented by
    *allowed_after_dict*.

    .. arg allowed_after_dict: A :class:`dict` mapping each :class:`str`
        name to a :class:`set` of names allowed to come after it.

    .. arg required_length: An :class:`int` length required of all
        returned orderings; shorter/longer orderings are dropped.

    .. arg return_first_found: A :class:`bool` specifying whether to
        return after the first valid ordering is found.

    .. return: A :class:`set` of all orderings *explicitly* allowed by
        the tree, i.e., if we know a->b and c->b, we don't know enough to
        return a->c->b. Note that if the set for a dict key is empty,
        nothing is allowed to come after that name.
    """
    orderings = set()
    _generate_orderings_starting_w_prefix(
        allowed_after_dict,
        orderings,
        required_length=required_length,
        start_prefix=(),
        return_first_found=return_first_found,
        )
    return orderings
def create_graph_from_pairs(before_after_pairs):
    """Build a digraph, as a dict {before: set(afters)}, from
    (before, after) edge pairs; only 'before' items appear as keys."""
    graph = {}
    for before, _ in before_after_pairs:
        graph[before] = set()
    for before, after in before_after_pairs:
        graph[before].add(after)
    return graph


# only used for example purposes:


def create_explicit_map_from_tuples(tuple_pairs, space):
    """Return a :class:`islpy.Map` in :class:`islpy.Space` *space*
    mapping ``tup_in -> tup_out`` for each ``(tup_in, tup_out)`` pair in
    *tuple_pairs*, where the tuples hold the :class:`int` values to be
    assigned to the corresponding dimension variables of *space*.
    """
    dim_type = isl.dim_type
    pair_maps = []

    for tup_in, tup_out in tuple_pairs:
        # one equality constraint per dim: dim - val = 0
        equalities = []
        for pos, val_in in enumerate(tup_in):
            equalities.append(
                isl.Constraint.equality_alloc(space)
                .set_coefficient_val(dim_type.in_, pos, 1)
                .set_constant_val(-1*val_in))
        for pos, val_out in enumerate(tup_out):
            equalities.append(
                isl.Constraint.equality_alloc(space)
                .set_coefficient_val(dim_type.out, pos, 1)
                .set_constant_val(-1*val_out))
        pair_maps.append(
            isl.Map.universe(space).add_constraints(equalities))

    union_map = pair_maps[0]
    for extra in pair_maps[1:]:
        union_map = union_map.union(extra)
    return union_map
# --- loopy/schedule/schedule_checker/schedule.py (beginning) ----------------
# (new file in this patch; file-level ``import islpy as isl`` kept at the
# module header, outside this block)


class LexScheduleStatement(object):
    """A representation of a Loopy statement.

    .. attribute:: insn_id

        A :class:`str` specifying the instruction id.

    .. attribute:: int_id

        An :class:`int` uniquely identifying the instruction within a
        :class:`LexSchedule`; 0 is a valid id, ``None`` means unassigned.

    .. attribute:: within_inames

        A :class:`list` of :class:`str` inames identifying the loops
        within which this statement will be executed.
    """

    def __init__(
            self,
            insn_id,  # loopy insn id
            int_id=None,  # sid int (statement id within LexSchedule)
            within_inames=None,  # [string, ]
            ):
        self.insn_id = insn_id  # string
        self.int_id = int_id
        self.within_inames = within_inames

    def __eq__(self, other):
        # robustness fix: comparing against a foreign type used to raise
        # AttributeError; defer to the other operand instead
        if not isinstance(other, LexScheduleStatement):
            return NotImplemented
        return (
            self.insn_id == other.insn_id
            and self.int_id == other.int_id
            and self.within_inames == other.within_inames
            )

    def __hash__(self):
        # bugfix: hash(repr(self)) used the default object repr, which
        # contains the instance address, so equal statements hashed
        # differently (breaking the __eq__/__hash__ contract); hash the
        # compared fields instead. frozenset() makes within_inames (a
        # list) hashable without depending on element order.
        within = (
            None if self.within_inames is None
            else frozenset(self.within_inames))
        return hash((self.insn_id, self.int_id, within))

    def update_persistent_hash(self, key_hash, key_builder):
        """Custom hash computation function for use with
        :class:`pytools.persistent_dict.PersistentDict`.
        """
        key_builder.rec(key_hash, self.insn_id)
        key_builder.rec(key_hash, self.int_id)
        key_builder.rec(key_hash, self.within_inames)

    def __str__(self):
        # bugfix: "if self.int_id:" silently hid the valid statement
        # id 0 (sid 0 is used elsewhere in this package); test for None
        if self.int_id is not None:
            int_id = ":%d" % (self.int_id)
        else:
            int_id = ""
        if self.within_inames:
            within_inames = " {%s}" % (",".join(self.within_inames))
        else:
            within_inames = ""
        return "%s%s%s" % (
            self.insn_id, int_id, within_inames)


class LexScheduleStatementInstance(object):
    """A representation of a Loopy statement instance.

    .. attribute:: stmt

        A :class:`LexScheduleStatement`.

    .. attribute:: lex_pt

        A list of :class:`int` or :class:`str` Loopy inames representing
        a point or set of points in a lexicographic ordering.
    """

    def __init__(
            self,
            stmt,  # a LexScheduleStatement
            lex_pt,  # [string/int, ]
            ):
        self.stmt = stmt
        self.lex_pt = lex_pt

    def __str__(self):
        return "{%s, %s}" % (self.stmt, self.lex_pt)
Points in lexicographic ordering are represented as
+        a list of :class:`int` or as :class:`str` Loopy inames.
+
+    .. attribute:: stmt_instance_after
+
+        A :class:`LexScheduleStatementInstance` describing the depender
+        statement's order relative to the dependee statement by mapping
+        a statement to a point or set of points in a lexicographic
+        ordering. Points in lexicographic ordering are represented as
+        a list of :class:`int` or as :class:`str` Loopy inames.
+
+    .. attribute:: unused_param_name
+
+        A :class:`str` that specifies the name of a dummy isl parameter
+        assigned to variables in domain elements of the isl map that
+        represent inames unused in a particular statement instance.
+        The domain space of the generated isl map will have a dimension
+        for every iname used in any statement instance found in the
+        program ordering. An element in the domain of this map may
+        represent a statement instance that does not lie within
+        iname x, but will still need to assign a value to the x domain
+        variable. In this case, the parameter unused_param_name
+        is assigned to x.
+
+    .. attribute:: statement_var_name
+
+        A :class:`str` specifying the name of the isl variable used
+        to represent the unique :class:`int` statement id.
+
+    .. attribute:: lex_var_prefix
+
+        A :class:`str` specifying the prefix to be used for the variables
+        representing the dimensions in the lexicographic ordering. E.g.,
+        a prefix of "lex" might yield variables "lex0", "lex1", "lex2".
+
+    """
+
+    unused_param_name = "unused"
+    statement_var_name = "statement"
+    lex_var_prefix = "l"
+
+    def __init__(
+            self,
+            knl,
+            sched_items_ordered,
+            before_insn_id,
+            after_insn_id,
+            prohibited_var_names=[],
+            ):
+        """
+        :arg knl: A :class:`LoopKernel` whose schedule items will be
+            described by this :class:`LexSchedule`.
+
+        :arg sched_items_ordered: A list of :class:`ScheduleItem` whose
+            order will be described by this :class:`LexSchedule`.
+ + :arg before_insn_id: A :class:`str` instruction id specifying + the dependee in this pair of instructions. + + :arg after_insn_id: A :class:`str` instruction id specifying + the depender in this pair of instructions. + + :arg prohibited_var_names: A list of :class:`str` variable names + that may not be used as the statement variable name (e.g., + because they are already being used as inames). + + """ + + # LexScheduleStatements + self.stmt_instance_before = None + self.stmt_instance_after = None + + # make sure we don't have an iname name conflict + assert not any( + iname == self.statement_var_name for iname in prohibited_var_names) + assert not any( + iname == self.unused_param_name for iname in prohibited_var_names) + + from loopy.schedule import (EnterLoop, LeaveLoop, Barrier, RunInstruction) + from loopy.kernel.data import ConcurrentTag + + # go through sched_items_ordered and generate self.lex_schedule + + # keep track of the next point in our lexicographic ordering + # initially this as a 1-d point with value 0 + next_insn_lex_pt = [0] + next_sid = 0 + for sched_item in sched_items_ordered: + if isinstance(sched_item, EnterLoop): + iname = sched_item.iname + if knl.iname_tags_of_type(iname, ConcurrentTag): + # In the future, this should be unnecessary because there + # won't be any inames with ConcurrentTags in the loopy sched + from warnings import warn + warn( + "LexSchedule.__init__: Encountered EnterLoop for iname %s " + "with ConcurrentTag(s) in schedule for kernel %s. " + "Ignoring this loop." 
% (iname, knl.name))
+                    continue
+
+                # if the schedule is empty, this is the first schedule item, so
+                # don't increment lex dim val enumerating items in current block,
+                # otherwise, this loop is next item in current code block, so
+                # increment lex dim val enumerating items in current code block
+                if self.stmt_instance_before or self.stmt_instance_after:
+                    # (if either statement has been set)
+                    # this lex value will correspond to everything inside this loop
+                    # we will add new lex dimensions to enumerate items inside loop
+                    next_insn_lex_pt[-1] = next_insn_lex_pt[-1]+1
+
+                # upon entering a loop, we enter a new (deeper) code block, so
+                # add one lex dimension for the loop variable, and
+                # add a second lex dim to enumerate code blocks within the new loop
+                next_insn_lex_pt.append(iname)
+                next_insn_lex_pt.append(0)
+            elif isinstance(sched_item, LeaveLoop):
+                if knl.iname_tags_of_type(sched_item.iname, ConcurrentTag):
+                    # In the future, this should be unnecessary because there
+                    # won't be any inames with ConcurrentTags in the loopy sched
+                    continue
+                # upon leaving a loop,
+                # pop lex dimension for enumerating code blocks within this loop, and
+                # pop lex dimension for the loop variable, and
+                # increment lex dim val enumerating items in current code block
+                next_insn_lex_pt.pop()
+                next_insn_lex_pt.pop()
+                next_insn_lex_pt[-1] = next_insn_lex_pt[-1]+1
+                # if we didn't add any statements while in this loop, we might
+                # sometimes be able to skip increment, but it's not hurting anything
+                # TODO might not need this increment period?
+            elif isinstance(sched_item, (RunInstruction, Barrier)):
+                from schedule_checker.sched_check_utils import (
+                    _get_insn_id_from_sched_item,
+                )
+                lp_insn_id = _get_insn_id_from_sched_item(sched_item)
+                if lp_insn_id is None:
+                    # TODO make sure it's okay to ignore barriers without id
+                    # (because they'll never be part of a dependency?)
+                    # matmul example has barrier that fails this assertion...
+ # assert sched_item.originating_insn_id is not None + continue + + # if include_only_insn_ids list was passed, + # only process insns found in list, + # otherwise process all instructions + if lp_insn_id == before_insn_id and lp_insn_id == after_insn_id: + # add before sched item + self.stmt_instance_before = LexScheduleStatementInstance( + LexScheduleStatement( + insn_id=lp_insn_id, + int_id=next_sid, # int representing insn + ), + next_insn_lex_pt[:]) + # add after sched item + self.stmt_instance_after = LexScheduleStatementInstance( + LexScheduleStatement( + insn_id=lp_insn_id, + int_id=next_sid, # int representing insn + ), + next_insn_lex_pt[:]) + + # increment lex dim val enumerating items in current code block + next_insn_lex_pt[-1] = next_insn_lex_pt[-1] + 1 + next_sid += 1 + elif lp_insn_id == before_insn_id: + # add before sched item + self.stmt_instance_before = LexScheduleStatementInstance( + LexScheduleStatement( + insn_id=lp_insn_id, + int_id=next_sid, # int representing insn + ), + next_insn_lex_pt[:]) + + # increment lex dim val enumerating items in current code block + next_insn_lex_pt[-1] = next_insn_lex_pt[-1] + 1 + next_sid += 1 + elif lp_insn_id == after_insn_id: + # add after sched item + self.stmt_instance_after = LexScheduleStatementInstance( + LexScheduleStatement( + insn_id=lp_insn_id, + int_id=next_sid, # int representing insn + ), + next_insn_lex_pt[:]) + + # increment lex dim val enumerating items in current code block + next_insn_lex_pt[-1] = next_insn_lex_pt[-1] + 1 + next_sid += 1 + else: + pass + # to save time, stop when we've created both statements + if self.stmt_instance_before and self.stmt_instance_after: + break + + # at this point, lex_schedule may contain lex points missing dimensions, + # the values in these missing dims should be zero, so add them + self.pad_lex_pts_with_zeros() + + def loopy_insn_id_to_lex_sched_id(self): + """Return a dictionary mapping insn_id to int_id, where ``insn_id`` and + ``int_id`` refer to 
the ``insn_id`` and ``int_id`` attributes of
+        :class:`LexScheduleStatement`.
+        """
+        return {
+            self.stmt_instance_before.stmt.insn_id:
+                self.stmt_instance_before.stmt.int_id,
+            self.stmt_instance_after.stmt.insn_id:
+                self.stmt_instance_after.stmt.int_id,
+            }
+
+    def max_lex_dims(self):
+        return max([
+            len(self.stmt_instance_before.lex_pt),
+            len(self.stmt_instance_after.lex_pt)])
+
+    def pad_lex_pts_with_zeros(self):
+        """Find the maximum number of lexicographic dimensions represented
+        in the lexicographic ordering, and if any
+        :class:`LexScheduleStatement` maps to a point in lexicographic
+        time with fewer dimensions, add a zero for each of the missing
+        dimensions.
+        """
+
+        max_lex_dim = self.max_lex_dims()
+        self.stmt_instance_before = LexScheduleStatementInstance(
+            self.stmt_instance_before.stmt,
+            self.stmt_instance_before.lex_pt[:] + [0]*(
+                max_lex_dim-len(self.stmt_instance_before.lex_pt))
+            )
+        self.stmt_instance_after = LexScheduleStatementInstance(
+            self.stmt_instance_after.stmt,
+            self.stmt_instance_after.lex_pt[:] + [0]*(
+                max_lex_dim-len(self.stmt_instance_after.lex_pt))
+            )
+
+    def create_symbolic_isl_maps(
+            self,
+            dom_before,
+            dom_after,
+            dom_inames_ordered_before=None,
+            dom_inames_ordered_after=None,
+            ):
+        """Create two isl maps representing lex schedule as two mappings
+        from statement instances to lexicographic time, one for
+        the dependee and one for the depender.
+
+        .. arg dom_before: A :class:`islpy.BasicSet` representing the
+            domain for the dependee statement.
+
+        .. arg dom_after: A :class:`islpy.BasicSet` representing the
+            domain for the depender statement.
+
+        .. arg dom_inames_ordered_before: A list of :class:`str`
+            representing the union of inames used in instances of the
+            dependee statement. ``statement_var_name`` and
+            ``dom_inames_ordered_before`` are the names of the dims of
+            the space of the ISL map domain for the dependee.
+
+        .. 
arg dom_inames_ordered_after: A list of :class:`str` + representing the union of inames used in instances of the + depender statement. ``statement_var_name`` and + ``dom_inames_ordered_after`` are the names of the dims of + the space of the ISL map domain for the depender. + + .. return: A two-tuple containing two :class:`islpy.Map`s + representing the schedule as two mappings + from statement instances to lexicographic time, one for + the dependee and one for the depender. + + """ + + from schedule_checker.sched_check_utils import ( + create_symbolic_isl_map_from_tuples, + add_dims_to_isl_set + ) + + from schedule_checker.sched_check_utils import ( + list_var_names_in_isl_sets, + ) + if dom_inames_ordered_before is None: + dom_inames_ordered_before = list_var_names_in_isl_sets( + [dom_before]) + if dom_inames_ordered_after is None: + dom_inames_ordered_after = list_var_names_in_isl_sets( + [dom_after]) + + # create an isl space + # {('statement', used in >=1 statement domain>) -> + # (lexicographic ordering dims)} + from schedule_checker.sched_check_utils import get_isl_space + params_sched = [self.unused_param_name] + out_names_sched = self.get_lex_var_names() + + in_names_sched_before = [ + self.statement_var_name] + dom_inames_ordered_before[:] + sched_space_before = get_isl_space( + params_sched, in_names_sched_before, out_names_sched) + in_names_sched_after = [ + self.statement_var_name] + dom_inames_ordered_after[:] + sched_space_after = get_isl_space( + params_sched, in_names_sched_after, out_names_sched) + + # Insert 'statement' dim into domain so that its space allows for + # intersection with sched map later + doms_to_intersect_before = [ + add_dims_to_isl_set( + dom_before, isl.dim_type.set, + [self.statement_var_name], 0), + ] + doms_to_intersect_after = [ + add_dims_to_isl_set( + dom_after, isl.dim_type.set, + [self.statement_var_name], 0), + ] + + # Each isl map representing the schedule maps + # statement instances -> lex time + + # Right now, 
statement tuples consist of single int. + # Add all inames from domains to map domain tuples. + + # create isl map + return ( + create_symbolic_isl_map_from_tuples( + zip( + [( + (self.stmt_instance_before.stmt.int_id,) + + tuple(dom_inames_ordered_before), + self.stmt_instance_before.lex_pt + )], + doms_to_intersect_before + ), + sched_space_before, self.unused_param_name, self.statement_var_name), + create_symbolic_isl_map_from_tuples( + zip( + [( + (self.stmt_instance_after.stmt.int_id,) + + tuple(dom_inames_ordered_after), + self.stmt_instance_after.lex_pt)], + doms_to_intersect_after + ), + sched_space_after, self.unused_param_name, self.statement_var_name) + ) + + def get_lex_var_names(self): + return [self.lex_var_prefix+str(i) + for i in range(self.max_lex_dims())] + + def get_lex_order_map_for_symbolic_sched(self): + """Return an :class:`islpy.BasicMap` that maps each point in a + lexicographic ordering to every point that is + lexocigraphically greater. + """ + + from schedule_checker.lexicographic_order_map import ( + create_lex_order_map, + ) + n_dims = self.max_lex_dims() + return create_lex_order_map( + n_dims, before_names=self.get_lex_var_names()) + + def __nonzero__(self): + return self.__bool__() + + def __eq__(self, other): + return ( + self.stmt_instance_before == other.stmt_instance_before + and self.stmt_instance_after == other.stmt_instance_after) + + def __str__(self): + sched_str = "Before: {\n" + domain_elem = "[%s=%s,]" % ( + self.statement_var_name, + self.stmt_instance_before.stmt.int_id) + sched_str += "%s -> %s;\n" % (domain_elem, self.stmt_instance_before.lex_pt) + sched_str += "}\n" + + sched_str += "After: {\n" + domain_elem += "[%s=%s,]" % ( + self.statement_var_name, + self.stmt_instance_after.stmt.int_id) + sched_str += "%s -> %s;\n" % (domain_elem, self.stmt_instance_after.lex_pt) + sched_str += "}" + return sched_str diff --git a/loopy/schedule/schedule_checker/test/test_invalid_scheds.py 
b/loopy/schedule/schedule_checker/test/test_invalid_scheds.py new file mode 100644 index 0000000000000000000000000000000000000000..05073502a95ce1f742bc5b41de374a3247a93a85 --- /dev/null +++ b/loopy/schedule/schedule_checker/test/test_invalid_scheds.py @@ -0,0 +1,168 @@ +from __future__ import division, print_function + +__copyright__ = "Copyright (C) 2018 James Stevens" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import sys +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl + as pytest_generate_tests) +import loopy as lp +from schedule_checker import ( + get_statement_pair_dependency_sets_from_legacy_knl, + check_schedule_validity, +) +from loopy.kernel import KernelState +from loopy import ( + preprocess_kernel, + get_one_scheduled_kernel, +) + + +def test_invalid_prioritiy_detection(): + ref_knl = lp.make_kernel( + [ + "{[h]: 0<=h acc = 0 + for h,i,j,k + acc = acc + h + i + j + k + end + """, + name="priorities", + assumptions="ni,nj,nk,nh >= 1", + lang_version=(2018, 2) + ) + + # no error: + knl0 = lp.prioritize_loops(ref_knl, "h,i") + knl0 = lp.prioritize_loops(ref_knl, "i,j") + knl0 = lp.prioritize_loops(knl0, "j,k") + + unprocessed_knl = knl0.copy() + + deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl( + unprocessed_knl) + if hasattr(lp, "add_dependencies_v2"): + knl0 = lp.add_dependencies_v2(knl0, deps_and_domains) + + # get a schedule to check + if knl0.state < KernelState.PREPROCESSED: + knl0 = preprocess_kernel(knl0) + knl0 = get_one_scheduled_kernel(knl0) + schedule_items = knl0.schedule + + sched_is_valid = check_schedule_validity( + unprocessed_knl, deps_and_domains, schedule_items) + assert sched_is_valid + + # no error: + knl1 = lp.prioritize_loops(ref_knl, "h,i,k") + knl1 = lp.prioritize_loops(knl1, "h,j,k") + + unprocessed_knl = knl1.copy() + + deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl( + unprocessed_knl) + if hasattr(lp, "add_dependencies_v2"): + knl1 = lp.add_dependencies_v2(knl1, deps_and_domains) + + # get a schedule to check + if knl1.state < KernelState.PREPROCESSED: + knl1 = preprocess_kernel(knl1) + knl1 = get_one_scheduled_kernel(knl1) + schedule_items = knl1.schedule + + sched_is_valid = check_schedule_validity( + unprocessed_knl, deps_and_domains, schedule_items) + assert sched_is_valid + + # error (cycle): + knl2 = lp.prioritize_loops(ref_knl, "h,i,j") + knl2 = 
lp.prioritize_loops(knl2, "j,k") + try: + if hasattr(lp, "constrain_loop_nesting"): + knl2 = lp.constrain_loop_nesting(knl2, "k,i") + else: + knl2 = lp.prioritize_loops(knl2, "k,i") + + unprocessed_knl = knl2.copy() + + deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl( + unprocessed_knl) + + # get a schedule to check + if knl2.state < KernelState.PREPROCESSED: + knl2 = preprocess_kernel(knl2) + knl2 = get_one_scheduled_kernel(knl2) + schedule_items = knl2.schedule + + sched_is_valid = check_schedule_validity( + unprocessed_knl, deps_and_domains, schedule_items) + # should raise error + assert False + except ValueError as e: + if hasattr(lp, "constrain_loop_nesting"): + assert "cycle detected" in str(e) + else: + assert "invalid priorities" in str(e) + + # error (inconsistent priorities): + knl3 = lp.prioritize_loops(ref_knl, "h,i,j,k") + try: + if hasattr(lp, "constrain_loop_nesting"): + knl3 = lp.constrain_loop_nesting(knl3, "h,j,i,k") + else: + knl3 = lp.prioritize_loops(knl3, "h,j,i,k") + + unprocessed_knl = knl3.copy() + + deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl( + unprocessed_knl) + + # get a schedule to check + if knl3.state < KernelState.PREPROCESSED: + knl3 = preprocess_kernel(knl3) + knl3 = get_one_scheduled_kernel(knl3) + schedule_items = knl3.schedule + + sched_is_valid = check_schedule_validity( + unprocessed_knl, deps_and_domains, schedule_items) + # should raise error + assert False + except ValueError as e: + if hasattr(lp, "constrain_loop_nesting"): + assert "cycle detected" in str(e) + else: + assert "invalid priorities" in str(e) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) diff --git a/loopy/schedule/schedule_checker/test/test_valid_scheds.py b/loopy/schedule/schedule_checker/test/test_valid_scheds.py new file mode 100644 index 0000000000000000000000000000000000000000..f12211dcee7ae09876f0a05d6dff3d24698f83de --- 
/dev/null +++ b/loopy/schedule/schedule_checker/test/test_valid_scheds.py @@ -0,0 +1,354 @@ +from __future__ import division, print_function + +__copyright__ = "Copyright (C) 2018 James Stevens" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import sys +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl + as pytest_generate_tests) +import loopy as lp +import numpy as np +from schedule_checker import ( + get_statement_pair_dependency_sets_from_legacy_knl, + check_schedule_validity, +) +from loopy.kernel import KernelState +from loopy import ( + preprocess_kernel, + get_one_scheduled_kernel, +) + + +def test_loop_prioritization(): + knl = lp.make_kernel( + [ + "{[i]: 0<=itemp = b[i,k] {id=insn_a} + end + for j + a[i,j] = temp + 1 {id=insn_b,dep=insn_a} + c[i,j] = d[i,j] {id=insn_c} + end + end + for t + e[t] = f[t] {id=insn_d} + end + """, + name="example", + assumptions="pi,pj,pk,pt >= 1", + lang_version=(2018, 2) + ) + knl = lp.add_and_infer_dtypes( + knl, + {"b": np.float32, "d": np.float32, "f": np.float32}) + knl = lp.prioritize_loops(knl, "i,k") + knl = lp.prioritize_loops(knl, "i,j") + + unprocessed_knl = knl.copy() + + deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl( + unprocessed_knl) + if hasattr(lp, "add_dependencies_v2"): + knl = lp.add_dependencies_v2(knl, deps_and_domains) + + # get a schedule to check + if knl.state < KernelState.PREPROCESSED: + knl = preprocess_kernel(knl) + knl = get_one_scheduled_kernel(knl) + schedule_items = knl.schedule + + sched_is_valid = check_schedule_validity( + unprocessed_knl, deps_and_domains, schedule_items) + assert sched_is_valid + + +def test_matmul(): + bsize = 16 + knl = lp.make_kernel( + "{[i,k,j]: 0<=i {[i,j]: 0<=i {[i]: 0<=i xi = qpts[1, i2] + <> s = 1-xi + <> r = xi/s + <> aind = 0 {id=aind_init} + for alpha1 + <> w = s**(deg-alpha1) {id=init_w} + for alpha2 + tmp[el,alpha1,i2] = tmp[el,alpha1,i2] + w * coeffs[aind] \ + {id=write_tmp,dep=init_w:aind_init} + w = w * r * ( deg - alpha1 - alpha2 ) / (1 + alpha2) \ + {id=update_w,dep=init_w:write_tmp} + aind = aind + 1 \ + {id=aind_incr,dep=aind_init:write_tmp:update_w} + end + end + end + """, + [lp.GlobalArg("coeffs", None, shape=None), "..."], + 
name="stroud_bernstein_orig", assumptions="deg>=0 and nels>=1") + knl = lp.add_and_infer_dtypes(knl, + dict(coeffs=np.float32, qpts=np.int32)) + knl = lp.fix_parameters(knl, nqp1d=7, deg=4) + knl = lp.split_iname(knl, "el", 16, inner_tag="l.0") + knl = lp.split_iname(knl, "el_outer", 2, outer_tag="g.0", + inner_tag="ilp", slabs=(0, 1)) + knl = lp.tag_inames(knl, dict(i2="l.1", alpha1="unr", alpha2="unr")) + + unprocessed_knl = knl.copy() + + deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl( + unprocessed_knl) + if hasattr(lp, "add_dependencies_v2"): + knl = lp.add_dependencies_v2(knl, deps_and_domains) + + # get a schedule to check + if knl.state < KernelState.PREPROCESSED: + knl = preprocess_kernel(knl) + knl = get_one_scheduled_kernel(knl) + schedule_items = knl.schedule + + sched_is_valid = check_schedule_validity( + unprocessed_knl, deps_and_domains, schedule_items) + assert sched_is_valid + + +def test_nop(): + knl = lp.make_kernel( + [ + "{[b]: b_start<=b c_end = 2 + for c + ... 
nop + end + end + """, + "...", + seq_dependencies=True) + knl = lp.fix_parameters(knl, dim=3) + + unprocessed_knl = knl.copy() + + deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl( + unprocessed_knl) + if hasattr(lp, "add_dependencies_v2"): + knl = lp.add_dependencies_v2(knl, deps_and_domains) + + # get a schedule to check + if knl.state < KernelState.PREPROCESSED: + knl = preprocess_kernel(knl) + knl = get_one_scheduled_kernel(knl) + schedule_items = knl.schedule + + sched_is_valid = check_schedule_validity( + unprocessed_knl, deps_and_domains, schedule_items) + assert sched_is_valid + + +def test_multi_domain(): + knl = lp.make_kernel( + [ + "{[i]: 0<=iacc = 0 {id=insn0} + for j + for k + acc = acc + j + k {id=insn1,dep=insn0} + end + end + end + end + """, + name="nest_multi_dom", + assumptions="ni,nj,nk,nx >= 1", + lang_version=(2018, 2) + ) + knl = lp.prioritize_loops(knl, "x,xx,i") + knl = lp.prioritize_loops(knl, "i,j") + knl = lp.prioritize_loops(knl, "j,k") + + unprocessed_knl = knl.copy() + + deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl( + unprocessed_knl) + if hasattr(lp, "add_dependencies_v2"): + knl = lp.add_dependencies_v2(knl, deps_and_domains) + + # get a schedule to check + if knl.state < KernelState.PREPROCESSED: + knl = preprocess_kernel(knl) + knl = get_one_scheduled_kernel(knl) + schedule_items = knl.schedule + + sched_is_valid = check_schedule_validity( + unprocessed_knl, deps_and_domains, schedule_items) + assert sched_is_valid + + +def test_loop_carried_deps(): + knl = lp.make_kernel( + "{[i]: 0<=iacc0 = 0 {id=insn0} + for i + acc0 = acc0 + i {id=insn1,dep=insn0} + <>acc2 = acc0 + i {id=insn2,dep=insn1} + <>acc3 = acc2 + i {id=insn3,dep=insn2} + <>acc4 = acc0 + i {id=insn4,dep=insn1} + end + """, + name="loop_carried_deps", + assumptions="n >= 1", + lang_version=(2018, 2) + ) + + unprocessed_knl = knl.copy() + + deps_and_domains = get_statement_pair_dependency_sets_from_legacy_knl( + 
unprocessed_knl) + if hasattr(lp, "add_dependencies_v2"): + knl = lp.add_dependencies_v2(knl, deps_and_domains) + + # get a schedule to check + if knl.state < KernelState.PREPROCESSED: + knl = preprocess_kernel(knl) + knl = get_one_scheduled_kernel(knl) + schedule_items = knl.schedule + + sched_is_valid = check_schedule_validity( + unprocessed_knl, deps_and_domains, schedule_items) + assert sched_is_valid + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) diff --git a/loopy/schedule/schedule_checker/version.py b/loopy/schedule/schedule_checker/version.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a75f587230d6b2226de9bfd3a9689a440a843b --- /dev/null +++ b/loopy/schedule/schedule_checker/version.py @@ -0,0 +1 @@ +VERSION_TEXT = "0.1"