diff --git a/loopy/check.py b/loopy/check.py index 616e75e33e4380f60512bb1571ffb0f41022286f..146391bf2533e35e7bc2f2091c9968fb5b321b6f 100644 --- a/loopy/check.py +++ b/loopy/check.py @@ -60,7 +60,6 @@ def check_identifiers_in_subst_rules(knl): # {{{ sanity checks run pre-scheduling - # FIXME: Replace with an enum. See # https://gitlab.tiker.net/inducer/loopy/issues/85 VALID_NOSYNC_SCOPES = frozenset(["local", "global", "any"]) @@ -489,7 +488,8 @@ def _check_variable_access_ordered_inner(kernel): # Check even for PRIVATE scope, to ensure intentional program order. - from loopy.symbolic import do_access_ranges_overlap_conservative + from loopy.symbolic import AccessRangeOverlapChecker + overlap_checker = AccessRangeOverlapChecker(kernel) for writer_id in writers: for other_id in readers | writers: @@ -516,10 +516,9 @@ def _check_variable_access_ordered_inner(kernel): or other_id in unaliased_readers)) # Do not enforce ordering for disjoint access ranges - if (not is_relationship_by_aliasing - and not do_access_ranges_overlap_conservative( - kernel, writer_id, "w", other_id, "any", - name)): + if (not is_relationship_by_aliasing and not + overlap_checker.do_access_ranges_overlap_conservative( + writer_id, "w", other_id, "any", name)): continue # Do not enforce ordering for aliasing-based relationships diff --git a/loopy/schedule/__init__.py b/loopy/schedule/__init__.py index 9725d7782e8c70bae2a931bf1e64f4e3ef13b148..2c9964b11ad30c5b6b6ffacacc7f67ef239b50a1 100644 --- a/loopy/schedule/__init__.py +++ b/loopy/schedule/__init__.py @@ -1435,6 +1435,9 @@ class DependencyTracker(object): self.reverse = reverse self.var_kind = var_kind + from loopy.symbolic import AccessRangeOverlapChecker + self.overlap_checker = AccessRangeOverlapChecker(kernel) + if var_kind == "local": self.relevant_vars = kernel.local_var_names() elif var_kind == "global": @@ -1555,10 +1558,9 @@ class DependencyTracker(object): if src_race_vars == tgt_race_vars and len(src_race_vars) == 1: race_var, = src_race_vars - from loopy.symbolic import do_access_ranges_overlap_conservative - if not do_access_ranges_overlap_conservative( - self.kernel, target.id, tgt_dir, - source_id, src_dir, race_var): + if not ( + self.overlap_checker.do_access_ranges_overlap_conservative( + target.id, tgt_dir, source_id, src_dir, race_var)): continue yield DependencyRecord( diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 0635c0e535c898da4604e7fd5c65dd087c7d37c4..0cc8f4ba6a1531d748bd90492f570dbb563d962d 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -1686,8 +1686,9 @@ class BatchedAccessRangeMapper(WalkMapper): def __init__(self, kernel, var_names, overestimate=None): self.kernel = kernel self.var_names = set(var_names) - self.access_ranges = dict((arg, None) for arg in var_names) - self.bad_subscripts = dict((arg, []) for arg in var_names) + from collections import defaultdict + self.access_ranges = defaultdict(lambda: None) + self.bad_subscripts = defaultdict(list) if overestimate is None: overestimate = False @@ -1744,6 +1745,18 @@ class BatchedAccessRangeMapper(WalkMapper): class AccessRangeMapper(object): + """**IMPORTANT** + + Using this class *will likely* lead to performance bottlenecks. + + To avoid performance issues, rewrite your code to use + BatchedAccessRangeMapper if at all possible. + + For *n* variables and *m* expressions, calling this class to compute the + access ranges will take *O(mn)* time for traversing the expressions. + + BatchedAccessRangeMapper does the same traversal in *O(m + n)* time. + """ def __init__(self, kernel, var_name, overestimate=None): self.var_name = var_name @@ -1764,69 +1777,77 @@ class AccessRangeMapper(object): # }}} -# {{{ do_access_ranges_overlap_conservative +# {{{ check if access ranges overlap -def _get_access_range_conservative(kernel, insn_id, access_dir, var_name): - insn = kernel.id_to_insn[insn_id] - from loopy.kernel.instruction import MultiAssignmentBase +class AccessRangeOverlapChecker(object): + """Used for checking for overlap between access ranges of instructions.""" - assert access_dir in ["w", "any"] + def __init__(self, kernel): + self.kernel = kernel + self.vars = kernel.get_written_variables() | kernel.get_read_variables() + + @memoize_method + def _get_access_ranges(self, insn_id, access_dir): + insn = self.kernel.id_to_insn[insn_id] - if not isinstance(insn, MultiAssignmentBase): + exprs = list(insn.assignees) if access_dir == "any": - return var_name in insn.dependency_names() - else: - return var_name in insn.write_dependency_names() + exprs.append(insn.expression) + exprs.extend(insn.predicates) - exprs = list(insn.assignees) - if access_dir == "any": - exprs.append(insn.expression) - exprs.extend(insn.predicates) + from collections import defaultdict + aranges = defaultdict(lambda: False) - arange = False - for expr in exprs: - arm = AccessRangeMapper(kernel, var_name, overestimate=True) - arm(expr, kernel.insn_inames(insn)) + arm = BatchedAccessRangeMapper(self.kernel, self.vars, overestimate=True) - if arm.bad_subscripts: - return True + for expr in exprs: + arm(expr, self.kernel.insn_inames(insn)) - expr_arange = arm.access_range - if expr_arange is None: - continue + for name, arange in six.iteritems(arm.access_ranges): + if arm.bad_subscripts[name]: + aranges[name] = True + continue + aranges[name] = arange - if arange is False: - arange = expr_arange - else: - arange = arange | expr_arange + return aranges - return arange + def _get_access_range_for_var(self, insn_id, access_dir, var_name): + assert access_dir in ["w", "any"] + insn = self.kernel.id_to_insn[insn_id] + # Access range checks only apply to assignment-style instructions. For + # non-assignments, we rely on read/write dependency information. + from loopy.kernel.instruction import MultiAssignmentBase + if not isinstance(insn, MultiAssignmentBase): + if access_dir == "any": + return var_name in insn.dependency_names() + else: + return var_name in insn.write_dependency_names() -def do_access_ranges_overlap_conservative( - kernel, insn1_id, insn1_dir, insn2_id, insn2_dir, var_name): - """Determine whether the access ranges to *var_name* in the two - given instructions overlap. This determination is made 'conservatively', - i.e. if precise information is unavailable, it is concluded that the - ranges overlap. + return self._get_access_ranges(insn_id, access_dir)[var_name] - :arg insn1_dir: either ``"w"`` or ``"any"``, to indicate which - type of access is desired--writing or any - :arg insn2_dir: either ``"w"`` or ``"any"`` - :returns: a :class:`bool` - """ + def do_access_ranges_overlap_conservative( + self, insn1, insn1_dir, insn2, insn2_dir, var_name): + """Determine whether the access ranges to *var_name* in the two + given instructions overlap. This determination is made 'conservatively', + i.e. if precise information is unavailable, it is concluded that the + ranges overlap. + + :arg insn1_dir: either ``"w"`` or ``"any"``, to indicate which + type of access is desired--writing or any + :arg insn2_dir: either ``"w"`` or ``"any"`` + :returns: a :class:`bool` + """ - insn1_arange = _get_access_range_conservative( - kernel, insn1_id, insn1_dir, var_name) - insn2_arange = _get_access_range_conservative( - kernel, insn2_id, insn2_dir, var_name) + insn1_arange = self._get_access_range_for_var(insn1, insn1_dir, var_name) + insn2_arange = self._get_access_range_for_var(insn2, insn2_dir, var_name) - if insn1_arange is False or insn2_arange is False: - return False - if insn1_arange is True or insn2_arange is True: - return True + if insn1_arange is False or insn2_arange is False: + return False + if insn1_arange is True or insn2_arange is True: + return True - return not (insn1_arange & insn2_arange).is_empty() + return not (insn1_arange & insn2_arange).is_empty() # }}}