diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 2f19e7e3e850daf8f982cbc3a1c3194ddbec53fd..db7792cce55c9b7851850c3059e821fc574c3270 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -746,7 +746,8 @@ def find_idempotence(kernel): for var in reads_map[insn.id] for writer_id in writer_map.get(var, set())) - # Find SCCs of dep_graph. + # Find SCCs of dep_graph. These are used for checking if the instruction is + # in a dependency cycle. from loopy.tools import compute_sccs sccs = dict((item, scc) diff --git a/loopy/tools.py b/loopy/tools.py index 7952b875a6b1ce4a7a8cac587eaf2abc16eed8d3..01d0641fc25c11a092185125604613819a0293ca 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -291,48 +291,49 @@ def compute_sccs(graph): while to_search: top = next(iter(to_search)) - stack = [top] + call_stack = [(top, iter(graph[top]), None)] + visit_stack = [] visiting = set() scc = [] - while stack: - top = stack[-1] + while call_stack: + top, children, last_popped_child = call_stack.pop() - if top in visiting: - for child in graph[top]: - if child in visiting: - # Update SCC root. - scc_root[top] = min( - scc_root[top], - scc_root[child]) - - # Add to the current SCC and check if we're the root. - scc.append(top) - - if visit_order[top] == scc_root[top]: - sccs.append(scc) - scc = [] - - to_search.discard(top) - visiting.remove(top) - - if top in visit_order: - stack.pop() - else: + if top not in visiting: + # Unvisited: mark as visited, initialize SCC root. count = len(visit_order) + visit_stack.append(top) visit_order[top] = count scc_root[top] = count visiting.add(top) + to_search.discard(top) - for child in graph[top]: - if child in visiting: - # Update SCC root. - scc_root[top] = min( - scc_root[top], - visit_order[child]) - elif child not in visit_order: - stack.append(child) + # Returned from a recursion, update SCC. + if last_popped_child is not None: + scc_root[top] = min( + scc_root[top], + scc_root[last_popped_child]) + + for child in children: + if child not in visit_order: + # Recurse. + call_stack.append((top, children, child)) + call_stack.append((child, iter(graph[child]), None)) + break + if child in visiting: + scc_root[top] = min( + scc_root[top], + visit_order[child]) + else: + if scc_root[top] == visit_order[top]: + scc = [] + while visit_stack[-1] != top: + scc.append(visit_stack.pop()) + scc.append(visit_stack.pop()) + for item in scc: + visiting.remove(item) + sccs.append(scc) return sccs diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 851a8d2a2a048e4b0fc3b87ce9d730dc49ff70af..6b5c77de349dbc1ea594c57f0f9b868b76df62e4 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -511,9 +511,9 @@ def infer_unknown_types(kernel, expect_completion=False): from loopy.tools import compute_sccs # To speed up processing, we sort the variables by computing the SCCs of the - # type dependency graph. Each SCC represents a set of variables whose type - # mutually depends on themselves. The SCCs are returned in topological - # order. + # type dependency graph. Each SCC represents a set of variables whose types + # mutually depend on themselves. The SCCs are returned and processed in + # topological order. sccs = compute_sccs(dep_graph) item_lookup = _DictUnionView([ @@ -529,17 +529,18 @@ def infer_unknown_types(kernel, expect_completion=False): from loopy.kernel.data import TemporaryVariable, KernelArgument - failed_names = set() for var_chain in sccs: changed_during_last_queue_run = False queue = var_chain[:] + failed_names = set() while queue or changed_during_last_queue_run: if not queue and changed_during_last_queue_run: changed_during_last_queue_run = False - # Optimization: If there's a single variable in the SCC and - # the type of variable does not depend on itself, then - # the type is known after a single iteration. + # Optimization: If there's a single variable in the SCC without + # a self-referential dependency, then the type is known after a + # single iteration (we don't need to look at the expressions + # again). if len(var_chain) == 1: single_var, = var_chain if single_var not in dep_graph[single_var]: diff --git a/test/test_misc.py b/test/test_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..aa6a2cd075e6dd69d5c92118032672f2a334a492 --- /dev/null +++ b/test/test_misc.py @@ -0,0 +1,79 @@ +from __future__ import division, absolute_import, print_function + +__copyright__ = "Copyright (C) 2016 Matt Wala" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import six # noqa +from six.moves import range + +import sys + +import logging +logger = logging.getLogger(__name__) + + +def test_compute_sccs(): + from loopy.tools import compute_sccs + import random + + rng = random.Random(0) + + def generate_random_graph(nnodes): + graph = dict((i, set()) for i in range(nnodes)) + for i in range(nnodes): + for j in range(nnodes): + # Edge probability 1/n: Generates decently interesting inputs. + if rng.randint(0, nnodes - 1) == 0: + graph[i].add(j) + return graph + + def verify_sccs(graph, sccs): + visited = set() + + def visit(node): + if node in visited: + return [] + else: + visited.add(node) + result = [] + for child in graph[node]: + result = result + visit(child) + return result + [node] + + for scc in sccs: + result = visit(scc[0]) + assert set(result) == set(scc), (set(result), set(scc)) + + for nnodes in range(10, 20): + for i in range(40): + graph = generate_random_graph(nnodes) + verify_sccs(graph, compute_sccs(graph)) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from py.test.cmdline import main + main([__file__]) + +# vim: foldmethod=marker