diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 9a0a5b233c68304816297ae2b1cf7c2928b37f66..db7792cce55c9b7851850c3059e821fc574c3270 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -738,61 +738,21 @@ def find_idempotence(kernel): (insn.id, insn.read_dependency_names() & var_names) for insn in kernel.instructions) - dep_graph = {} + from collections import defaultdict + dep_graph = defaultdict(lambda: set()) for insn in kernel.instructions: dep_graph[insn.id] = set(writer_id for var in reads_map[insn.id] for writer_id in writer_map.get(var, set())) - # {{{ find SCCs of dep_graph - - def dfs(graph, root_node=None, exclude=frozenset()): - postorder = [] - visited = set() - have_root_node = root_node is not None - to_search = set([root_node] if have_root_node else graph.keys()) - - while to_search: - stack = [next(iter(to_search))] - visiting = set() - - while stack: - top = stack[-1] - if top in visiting: - visiting.discard(top) - postorder.append(top) - to_search.discard(top) - if top in visited: - stack.pop() - else: - visiting.add(top) - visited.add(top) - stack.extend( - item for item in graph[top] - if item not in visited and item not in exclude) - - return postorder - - inv_dep_graph = dict((insn.id, set()) for insn in kernel.instructions) - for key, vals in six.iteritems(dep_graph): - for val in vals: - inv_dep_graph[val].add(key) - - postorder = dfs(inv_dep_graph) - - sccs = {} - exclude = set() - - for item in reversed(postorder): - if item in sccs: - continue - scc = dfs(dep_graph, root_node=item, exclude=exclude) - exclude.update(scc) - for scc_item in scc: - sccs[scc_item] = scc + # Find SCCs of dep_graph. These are used for checking if the instruction is + # in a dependency cycle. + from loopy.tools import compute_sccs - # }}} + sccs = dict((item, scc) + for scc in compute_sccs(dep_graph) + for item in scc) non_idempotently_updated_vars = set() diff --git a/loopy/tools.py b/loopy/tools.py index ae370d5aaac9ff75f530e1d0951a2f904b686e42..01d0641fc25c11a092185125604613819a0293ca 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -281,6 +281,65 @@ def empty_aligned(shape, dtype, order='C', n=64): # }}} +# {{{ compute SCCs with Tarjan's algorithm + +def compute_sccs(graph): + to_search = set(graph.keys()) + visit_order = {} + scc_root = {} + sccs = [] + + while to_search: + top = next(iter(to_search)) + call_stack = [(top, iter(graph[top]), None)] + visit_stack = [] + visiting = set() + + scc = [] + + while call_stack: + top, children, last_popped_child = call_stack.pop() + + if top not in visiting: + # Unvisited: mark as visited, initialize SCC root. + count = len(visit_order) + visit_stack.append(top) + visit_order[top] = count + scc_root[top] = count + visiting.add(top) + to_search.discard(top) + + # Returned from a recursion, update SCC. + if last_popped_child is not None: + scc_root[top] = min( + scc_root[top], + scc_root[last_popped_child]) + + for child in children: + if child not in visit_order: + # Recurse. + call_stack.append((top, children, child)) + call_stack.append((child, iter(graph[child]), None)) + break + if child in visiting: + scc_root[top] = min( + scc_root[top], + visit_order[child]) + else: + if scc_root[top] == visit_order[top]: + scc = [] + while visit_stack[-1] != top: + scc.append(visit_stack.pop()) + scc.append(visit_stack.pop()) + for item in scc: + visiting.remove(item) + sccs.append(scc) + + return sccs + +# }}} + + def is_interned(s): return s is None or intern(s) is s diff --git a/loopy/type_inference.py b/loopy/type_inference.py index a31f011a0ce8e5403b54984eb45db0970a8370b0..99a16bfc23341dba3d28c71038681c31d3e00dba 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -112,32 +112,28 @@ class TypeInferenceMapper(CombineMapper): 0 <= len(dtype_set) <= 1 for dtype_set in dtype_sets) - if not all( - isinstance(dtype, NumpyType) + from pytools import is_single_valued + + dtypes = [dtype for dtype_set in dtype_sets - for dtype in dtype_set): - from pytools import is_single_valued, single_valued - if not is_single_valued( - dtype - for dtype_set in dtype_sets - for dtype in dtype_set): + for dtype in dtype_set] + + if not all(isinstance(dtype, NumpyType) for dtype in dtypes): + if not is_single_valued(dtypes): raise TypeInferenceFailure( "Nothing known about operations between '%s'" - % ", ".join(str(dtype) - for dtype_set in dtype_sets - for dtype in dtype_set)) + % ", ".join(str(dtype) for dtype in dtypes)) - return single_valued(dtype - for dtype_set in dtype_sets - for dtype in dtype_set) + return [dtypes[0]] - numpy_dtypes = [dtype.dtype - for dtype_set in dtype_sets - for dtype in dtype_set] + numpy_dtypes = [dtype.dtype for dtype in dtypes] if not numpy_dtypes: return [] + if is_single_valued(numpy_dtypes): + return [dtypes[0]] + result = numpy_dtypes.pop() while numpy_dtypes: other = numpy_dtypes.pop() @@ -179,7 +175,6 @@ class TypeInferenceMapper(CombineMapper): else: dtype_sets.append(dtype_set) - from pytools import all if all(dtype.is_integral() for dtype_set in dtype_sets for dtype in dtype_set): @@ -462,6 +457,9 @@ def infer_unknown_types(kernel, expect_completion=False): logger.debug("%s: infer types" % kernel.name) + import time + start_time = time.time() + def debug(s): logger.debug("%s: %s" % (kernel.name, s)) @@ -489,6 +487,27 @@ def infer_unknown_types(kernel, expect_completion=False): # }}} + logger.debug("finding types for {count:d} names".format( + count=len(names_for_type_inference))) + + writer_map = kernel.writer_map() + + dep_graph = dict( + (written_var, set( + read_var + for insn_id in writer_map.get(written_var, []) + for read_var in kernel.id_to_insn[insn_id].read_dependency_names() + if read_var in names_for_type_inference)) + for written_var in names_for_type_inference) + + from loopy.tools import compute_sccs + + # To speed up processing, we sort the variables by computing the SCCs of the + # type dependency graph. Each SCC represents a set of variables whose types + # mutually depend on themselves. The SCCs are returned and processed in + # topological order. + sccs = compute_sccs(dep_graph) + item_lookup = _DictUnionView([ new_temp_vars, new_arg_dict @@ -502,75 +521,89 @@ def infer_unknown_types(kernel, expect_completion=False): from loopy.kernel.data import TemporaryVariable, KernelArgument - changed_during_last_queue_run = False - queue = names_for_type_inference[:] - - failed_names = set() - while queue or changed_during_last_queue_run: - if not queue and changed_during_last_queue_run: - changed_during_last_queue_run = False - queue = names_for_type_inference[:] - - name = queue.pop(0) - item = item_lookup[name] - - debug("inferring type for %s %s" % (type(item).__name__, item.name)) - - result, symbols_with_unavailable_types = \ - _infer_var_type(kernel, item.name, type_inf_mapper, subst_expander) - - failed = not result - if not failed: - new_dtype, = result - debug(" success: %s" % new_dtype) - if new_dtype != item.dtype: - debug(" changed from: %s" % item.dtype) - changed_during_last_queue_run = True + for var_chain in sccs: + changed_during_last_queue_run = False + queue = var_chain[:] + failed_names = set() + + while queue or changed_during_last_queue_run: + if not queue and changed_during_last_queue_run: + changed_during_last_queue_run = False + # Optimization: If there's a single variable in the SCC without + # a self-referential dependency, then the type is known after a + # single iteration (we don't need to look at the expressions + # again). + if len(var_chain) == 1: + single_var, = var_chain + if single_var not in dep_graph[single_var]: + break + queue = var_chain[:] + + name = queue.pop(0) + item = item_lookup[name] + + debug("inferring type for %s %s" % (type(item).__name__, item.name)) + + result, symbols_with_unavailable_types = ( + _infer_var_type( + kernel, item.name, type_inf_mapper, subst_expander)) + + failed = not result + if not failed: + new_dtype, = result + debug(" success: %s" % new_dtype) + if new_dtype != item.dtype: + debug(" changed from: %s" % item.dtype) + changed_during_last_queue_run = True + + if isinstance(item, TemporaryVariable): + new_temp_vars[name] = item.copy(dtype=new_dtype) + elif isinstance(item, KernelArgument): + new_arg_dict[name] = item.copy(dtype=new_dtype) + else: + raise LoopyError("unexpected item type in type inference") + else: + debug(" failure") + + if failed: + if item.name in failed_names: + # this item has failed before, give up. + advice = "" + if symbols_with_unavailable_types: + advice += ( + " (need type of '%s'--check for missing arguments)" + % ", ".join(symbols_with_unavailable_types)) + + if expect_completion: + raise LoopyError( + "could not determine type of '%s'%s" + % (item.name, advice)) + + else: + # We're done here. + break - if isinstance(item, TemporaryVariable): - new_temp_vars[name] = item.copy(dtype=new_dtype) - elif isinstance(item, KernelArgument): - new_arg_dict[name] = item.copy(dtype=new_dtype) - else: - raise LoopyError("unexpected item type in type inference") - else: - debug(" failure") - - if failed: - if item.name in failed_names: - # this item has failed before, give up. - advice = "" - if symbols_with_unavailable_types: - advice += ( - " (need type of '%s'--check for missing arguments)" - % ", ".join(symbols_with_unavailable_types)) - - if expect_completion: - raise LoopyError( - "could not determine type of '%s'%s" - % (item.name, advice)) + # remember that this item failed + failed_names.add(item.name) - else: - # We're done here. + if set(queue) == failed_names: + # We did what we could... + print(queue, failed_names, item.name) + assert not expect_completion break - # remember that this item failed - failed_names.add(item.name) - - if set(queue) == failed_names: - # We did what we could... - print(queue, failed_names, item.name) - assert not expect_completion - break - - # can't infer type yet, put back into queue - queue.append(name) - else: - # we've made progress, reset failure markers - failed_names = set() + # can't infer type yet, put back into queue + queue.append(name) + else: + # we've made progress, reset failure markers + failed_names = set() # }}} + end_time = time.time() + logger.debug("type inference took {dur:.2f} seconds".format( + dur=end_time - start_time)) + return unexpanded_kernel.copy( temporary_variables=new_temp_vars, args=[new_arg_dict[arg.name] for arg in kernel.args], diff --git a/test/test_loopy.py b/test/test_loopy.py index e41d55b85e504bcd39db37bd888ddbedbf6122f4..6b607109678c0b280113707ee77c0ede7df8f72d 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -102,6 +102,28 @@ def test_type_inference_no_artificial_doubles(ctx_factory): assert "double" not in code +def test_type_inference_with_type_dependencies(): + knl = lp.make_kernel( + "{[i]: i=0}", + """ + <>a = 99 + a = a + 1 + <>b = 0 + <>c = 1 + b = b + c + 1.0 + c = b + c + <>d = b + 2 + 1j + """, + "...") + knl = lp.infer_unknown_types(knl) + + from loopy.types import to_loopy_type + assert knl.temporary_variables["a"].dtype == to_loopy_type(np.int32) + assert knl.temporary_variables["b"].dtype == to_loopy_type(np.float32) + assert knl.temporary_variables["c"].dtype == to_loopy_type(np.float32) + assert knl.temporary_variables["d"].dtype == to_loopy_type(np.complex128) + + def test_sized_and_complex_literals(ctx_factory): ctx = ctx_factory() diff --git a/test/test_misc.py b/test/test_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..384c1326b75850f8c43c50914934f7dc5b097404 --- /dev/null +++ b/test/test_misc.py @@ -0,0 +1,79 @@ +from __future__ import division, absolute_import, print_function + +__copyright__ = "Copyright (C) 2016 Matt Wala" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import six # noqa +from six.moves import range + +import sys + +import logging +logger = logging.getLogger(__name__) + + +def test_compute_sccs(): + from loopy.tools import compute_sccs + import random + + rng = random.Random(0) + + def generate_random_graph(nnodes): + graph = dict((i, set()) for i in range(nnodes)) + for i in range(nnodes): + for j in range(nnodes): + # Edge probability 2/n: Generates decently interesting inputs. + if rng.randint(0, nnodes - 1) <= 1: + graph[i].add(j) + return graph + + def verify_sccs(graph, sccs): + visited = set() + + def visit(node): + if node in visited: + return [] + else: + visited.add(node) + result = [] + for child in graph[node]: + result = result + visit(child) + return result + [node] + + for scc in sccs: + result = visit(scc[0]) + assert set(result) == set(scc), (set(result), set(scc)) + + for nnodes in range(10, 20): + for i in range(40): + graph = generate_random_graph(nnodes) + verify_sccs(graph, compute_sccs(graph)) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from py.test.cmdline import main + main([__file__]) + +# vim: foldmethod=marker