From cc7c83295b03a35a73652394e9fa479e6640876e Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Thu, 15 Dec 2016 18:36:13 -0600 Subject: [PATCH] Type inference: Get a modest speedup by processing the variables in topological order, chunked according to SCCs of the dependency graph. See also #15. --- loopy/preprocess.py | 55 ++------------ loopy/tools.py | 58 ++++++++++++++ loopy/type_inference.py | 162 +++++++++++++++++++++++++--------------- 3 files changed, 165 insertions(+), 110 deletions(-) diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 9a0a5b233..2f19e7e3e 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -738,61 +738,20 @@ def find_idempotence(kernel): (insn.id, insn.read_dependency_names() & var_names) for insn in kernel.instructions) - dep_graph = {} + from collections import defaultdict + dep_graph = defaultdict(lambda: set()) for insn in kernel.instructions: dep_graph[insn.id] = set(writer_id for var in reads_map[insn.id] for writer_id in writer_map.get(var, set())) - # {{{ find SCCs of dep_graph - - def dfs(graph, root_node=None, exclude=frozenset()): - postorder = [] - visited = set() - have_root_node = root_node is not None - to_search = set([root_node] if have_root_node else graph.keys()) - - while to_search: - stack = [next(iter(to_search))] - visiting = set() - - while stack: - top = stack[-1] - if top in visiting: - visiting.discard(top) - postorder.append(top) - to_search.discard(top) - if top in visited: - stack.pop() - else: - visiting.add(top) - visited.add(top) - stack.extend( - item for item in graph[top] - if item not in visited and item not in exclude) - - return postorder - - inv_dep_graph = dict((insn.id, set()) for insn in kernel.instructions) - for key, vals in six.iteritems(dep_graph): - for val in vals: - inv_dep_graph[val].add(key) - - postorder = dfs(inv_dep_graph) - - sccs = {} - exclude = set() - - for item in reversed(postorder): - if item in sccs: - continue - scc = dfs(dep_graph, root_node=item, exclude=exclude) - exclude.update(scc) - for scc_item in scc: - sccs[scc_item] = scc + # Find SCCs of dep_graph. + from loopy.tools import compute_sccs - # }}} + sccs = dict((item, scc) + for scc in compute_sccs(dep_graph) + for item in scc) non_idempotently_updated_vars = set() diff --git a/loopy/tools.py b/loopy/tools.py index ae370d5aa..7952b875a 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -281,6 +281,64 @@ def empty_aligned(shape, dtype, order='C', n=64): # }}} +# {{{ compute SCCs with Tarjan's algorithm + +def compute_sccs(graph): + to_search = set(graph.keys()) + visit_order = {} + scc_root = {} + sccs = [] + + while to_search: + top = next(iter(to_search)) + stack = [top] + visiting = set() + + scc = [] + + while stack: + top = stack[-1] + + if top in visiting: + for child in graph[top]: + if child in visiting: + # Update SCC root. + scc_root[top] = min( + scc_root[top], + scc_root[child]) + + # Add to the current SCC and check if we're the root. + scc.append(top) + + if visit_order[top] == scc_root[top]: + sccs.append(scc) + scc = [] + + to_search.discard(top) + visiting.remove(top) + + if top in visit_order: + stack.pop() + else: + count = len(visit_order) + visit_order[top] = count + scc_root[top] = count + visiting.add(top) + + for child in graph[top]: + if child in visiting: + # Update SCC root. + scc_root[top] = min( + scc_root[top], + visit_order[child]) + elif child not in visit_order: + stack.append(child) + + return sccs + +# }}} + + def is_interned(s): return s is None or intern(s) is s diff --git a/loopy/type_inference.py b/loopy/type_inference.py index a31f011a0..a5ea0de7b 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -30,6 +30,8 @@ import numpy as np from loopy.tools import is_integer from loopy.types import NumpyType +from pytools import memoize_method + from loopy.diagnostic import ( LoopyError, TypeInferenceFailure, DependencyTypeInferenceFailure) @@ -202,6 +204,7 @@ class TypeInferenceMapper(CombineMapper): else: return self.combine([n_dtype_set, d_dtype_set]) + @memoize_method def map_constant(self, expr): if is_integer(expr): for tp in [np.int32, np.int64]: @@ -462,6 +465,9 @@ def infer_unknown_types(kernel, expect_completion=False): logger.debug("%s: infer types" % kernel.name) + import time + start_time = time.time() + def debug(s): logger.debug("%s: %s" % (kernel.name, s)) @@ -489,6 +495,26 @@ def infer_unknown_types(kernel, expect_completion=False): # }}} + logger.debug("finding types for {} names".format(len(names_for_type_inference))) + + writer_map = kernel.writer_map() + + dep_graph = dict( + (written_var, set( + read_var + for insn_id in writer_map.get(written_var, []) + for read_var in kernel.id_to_insn[insn_id].read_dependency_names() + if read_var in names_for_type_inference)) + for written_var in names_for_type_inference) + + from loopy.tools import compute_sccs + + # To speed up processing, we sort the variables by computing the SCCs of the + # type dependency graph. Each SCC represents a set of variables whose type + # mutually depends on themselves. The SCCs are returned in topological + # order. + sccs = compute_sccs(dep_graph) + item_lookup = _DictUnionView([ new_temp_vars, new_arg_dict @@ -502,75 +528,87 @@ def infer_unknown_types(kernel, expect_completion=False): from loopy.kernel.data import TemporaryVariable, KernelArgument - changed_during_last_queue_run = False - queue = names_for_type_inference[:] - failed_names = set() - while queue or changed_during_last_queue_run: - if not queue and changed_during_last_queue_run: - changed_during_last_queue_run = False - queue = names_for_type_inference[:] - - name = queue.pop(0) - item = item_lookup[name] - - debug("inferring type for %s %s" % (type(item).__name__, item.name)) - - result, symbols_with_unavailable_types = \ - _infer_var_type(kernel, item.name, type_inf_mapper, subst_expander) - - failed = not result - if not failed: - new_dtype, = result - debug(" success: %s" % new_dtype) - if new_dtype != item.dtype: - debug(" changed from: %s" % item.dtype) - changed_during_last_queue_run = True - - if isinstance(item, TemporaryVariable): - new_temp_vars[name] = item.copy(dtype=new_dtype) - elif isinstance(item, KernelArgument): - new_arg_dict[name] = item.copy(dtype=new_dtype) - else: - raise LoopyError("unexpected item type in type inference") - else: - debug(" failure") - - if failed: - if item.name in failed_names: - # this item has failed before, give up. - advice = "" - if symbols_with_unavailable_types: - advice += ( - " (need type of '%s'--check for missing arguments)" - % ", ".join(symbols_with_unavailable_types)) - - if expect_completion: - raise LoopyError( - "could not determine type of '%s'%s" - % (item.name, advice)) - - else: - # We're done here. - break + for var_chain in sccs: + changed_during_last_queue_run = False + queue = var_chain[:] + + while queue or changed_during_last_queue_run: + if not queue and changed_during_last_queue_run: + changed_during_last_queue_run = False + # Optimization: If there's a single variable in the SCC and + # the type of variable does not depend on itself, then + # the type is known after a single iteration. + if len(var_chain) == 1: + single_var, = var_chain + if single_var not in dep_graph[single_var]: + break + queue = var_chain[:] + + name = queue.pop(0) + item = item_lookup[name] + + debug("inferring type for %s %s" % (type(item).__name__, item.name)) + + result, symbols_with_unavailable_types = ( + _infer_var_type( + kernel, item.name, type_inf_mapper, subst_expander)) + + failed = not result + if not failed: + new_dtype, = result + debug(" success: %s" % new_dtype) + if new_dtype != item.dtype: + debug(" changed from: %s" % item.dtype) + changed_during_last_queue_run = True + + if isinstance(item, TemporaryVariable): + new_temp_vars[name] = item.copy(dtype=new_dtype) + elif isinstance(item, KernelArgument): + new_arg_dict[name] = item.copy(dtype=new_dtype) + else: + raise LoopyError("unexpected item type in type inference") + else: + debug(" failure") + + if failed: + if item.name in failed_names: + # this item has failed before, give up. + advice = "" + if symbols_with_unavailable_types: + advice += ( + " (need type of '%s'--check for missing arguments)" + % ", ".join(symbols_with_unavailable_types)) + + if expect_completion: + raise LoopyError( + "could not determine type of '%s'%s" + % (item.name, advice)) + + else: + # We're done here. + break - # remember that this item failed - failed_names.add(item.name) + # remember that this item failed + failed_names.add(item.name) - if set(queue) == failed_names: - # We did what we could... - print(queue, failed_names, item.name) - assert not expect_completion - break + if set(queue) == failed_names: + # We did what we could... + print(queue, failed_names, item.name) + assert not expect_completion + break - # can't infer type yet, put back into queue - queue.append(name) - else: - # we've made progress, reset failure markers - failed_names = set() + # can't infer type yet, put back into queue + queue.append(name) + else: + # we've made progress, reset failure markers + failed_names = set() # }}} + end_time = time.time() + logger.debug("type inference took {:.2f} seconds".format(end_time - start_time)) + return unexpanded_kernel.copy( temporary_variables=new_temp_vars, args=[new_arg_dict[arg.name] for arg in kernel.args], -- GitLab