diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 6b5488a20bc9d714fb5fde908b559ddebf4b9591..9a0a5b233c68304816297ae2b1cf7c2928b37f66 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -738,52 +738,83 @@ def find_idempotence(kernel): (insn.id, insn.read_dependency_names() & var_names) for insn in kernel.instructions) - non_idempotently_updated_vars = set() - - # FIXME: This can be made more efficient by simply starting - # from all written variables and not even considering - # instructions as the start of the first pass. + dep_graph = {} - new_insns = [] for insn in kernel.instructions: - all_my_var_writers = set() - for var in reads_map[insn.id]: - var_writers = writer_map.get(var, set()) - all_my_var_writers |= var_writers - - # {{{ find dependency loops, flag boostability - - while True: - last_all_my_var_writers = all_my_var_writers - - for writer_insn_id in last_all_my_var_writers: - for var in reads_map[writer_insn_id]: - all_my_var_writers = \ - all_my_var_writers | writer_map.get(var, set()) + dep_graph[insn.id] = set(writer_id + for var in reads_map[insn.id] + for writer_id in writer_map.get(var, set())) + + # {{{ find SCCs of dep_graph + + def dfs(graph, root_node=None, exclude=frozenset()): + postorder = [] + visited = set() + have_root_node = root_node is not None + to_search = set([root_node] if have_root_node else graph.keys()) + + while to_search: + stack = [next(iter(to_search))] + visiting = set() + + while stack: + top = stack[-1] + if top in visiting: + visiting.discard(top) + postorder.append(top) + to_search.discard(top) + if top in visited: + stack.pop() + else: + visiting.add(top) + visited.add(top) + stack.extend( + item for item in graph[top] + if item not in visited and item not in exclude) + + return postorder + + inv_dep_graph = dict((insn.id, set()) for insn in kernel.instructions) + for key, vals in six.iteritems(dep_graph): + for val in vals: + inv_dep_graph[val].add(key) + + postorder = dfs(inv_dep_graph) + + sccs = {} + exclude = set() + + for item in reversed(postorder): + if item in sccs: + continue + scc = dfs(dep_graph, root_node=item, exclude=exclude) + exclude.update(scc) + for scc_item in scc: + sccs[scc_item] = scc - if last_all_my_var_writers == all_my_var_writers: - break + # }}} - # }}} + non_idempotently_updated_vars = set() - boostable = insn.id not in all_my_var_writers + new_insns = [] + for insn in kernel.instructions: + boostable = len(sccs[insn.id]) == 1 and insn.id not in dep_graph[insn.id] if not boostable: non_idempotently_updated_vars.update( insn.assignee_var_names()) - insn = insn.copy(boostable=boostable) - - new_insns.append(insn) + new_insns.append(insn.copy(boostable=boostable)) # {{{ remove boostability from isns that access non-idempotently updated vars new2_insns = [] for insn in new_insns: - accessed_vars = insn.dependency_names() - boostable = insn.boostable and not bool( - non_idempotently_updated_vars & accessed_vars) - new2_insns.append(insn.copy(boostable=boostable)) + if insn.boostable and bool( + non_idempotently_updated_vars & insn.dependency_names()): + new2_insns.append(insn.copy(boostable=False)) + else: + new2_insns.append(insn) # }}}