diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py
index 840b04da0c65ad4a425cee47ab88964caf8a424c..c1e9197b516de9081ad230b021c4c02a4597670c 100644
--- a/sumpy/assignment_collection.py
+++ b/sumpy/assignment_collection.py
@@ -48,7 +48,14 @@ class _SymbolGenerator(object):
         from collections import defaultdict
         self.base_to_count = defaultdict(lambda: 0)
 
+    def _normalize(self, base):
+        # Strip off any _N suffix, to avoid generating conflicting names.
+        import re
+        base = re.split(r"_\d+$", base)[0]
+        return base if base != "" else "expr"
+
     def __call__(self, base="expr"):
+        base = self._normalize(base)
         count = self.base_to_count[base]
 
         def make_id_str(base, count):
diff --git a/sumpy/cse.py b/sumpy/cse.py
index ca82b971c018326d95890bdebfd17e43443849a0..350a2c8d97d6ef4729657c85248d412e4128f6a1 100644
--- a/sumpy/cse.py
+++ b/sumpy/cse.py
@@ -174,10 +174,11 @@ class FuncArgTracker(object):
         for arg in self.func_to_argset[func_i]:
             self.arg_to_funcset[arg].remove(func_i)
 
-    def gen_common_arg_candidates(self, argset, min_func_i, threshold=2):
+    def get_common_arg_candidates(self, argset, min_func_i, threshold=2):
         """
-        Generate the list of functions which have at least `threshold` arguments in
-        common from `argset`.
+        Return a dict mapping each function number to the number of arguments
+        that function has in common with `argset`. Only functions with at
+        least `threshold` arguments in common are included.
         """
         from collections import defaultdict
         count_map = defaultdict(lambda: 0)
@@ -187,28 +188,28 @@ class FuncArgTracker(object):
                 if func_i >= min_func_i:
                     count_map[func_i] += 1
 
-        count_map_keys_in_order = sorted(
-            key for key, val in count_map.items()
-            if val >= threshold)
+        return dict(
+            (k, v) for k, v in count_map.items()
+            if v >= threshold)
 
-        for item in count_map_keys_in_order:
-            yield item
-
-    def gen_subset_candidates(self, argset, min_func_i):
+    def get_subset_candidates(self, argset, restrict_to_funcset=None):
         """
-        Generate the list of functions whose set of arguments contains `argset`.
+        Return the set of functions whose argument sets contain all of
+        `argset`, optionally restricted to the functions in
+        `restrict_to_funcset`.
         """
         iarg = iter(argset)
 
         indices = set(
-            fi for fi in self.arg_to_funcset[next(iarg)]
-            if fi >= min_func_i)
+            fi for fi in self.arg_to_funcset[next(iarg)])
+
+        if restrict_to_funcset is not None:
+            indices &= restrict_to_funcset
 
         for arg in iarg:
             indices &= self.arg_to_funcset[arg]
 
-        for item in indices:
-            yield item
+        return indices
 
     def update_func_argset(self, func_i, new_argset):
         """
@@ -246,13 +247,28 @@ def match_common_args(func_class, funcs, opt_subs):
     :arg funcs: A list of function calls
     :arg opt_subs: A dictionary of substitutions which this function may update
     """
+
+    # Sort to ensure that whole-function subexpressions come before the items
+    # that use them.
+    funcs = sorted(funcs, key=lambda f: len(f.args))
     arg_tracker = FuncArgTracker(funcs)
 
     changed = set()
 
+    from sumpy.tools import OrderedSet
+
     for i in range(len(funcs)):
-        for j in arg_tracker.gen_common_arg_candidates(
-                arg_tracker.func_to_argset[i], i + 1, threshold=2):
+        common_arg_candidates = arg_tracker.get_common_arg_candidates(
+                arg_tracker.func_to_argset[i], i + 1, threshold=2)
+
+        # Sort the candidates in order of match size.
+        # This makes us try combining smaller matches first.
+        common_arg_candidates = OrderedSet(sorted(
+                common_arg_candidates.keys(),
+                key=lambda k: (common_arg_candidates[k], k)))
+
+        while common_arg_candidates:
+            j = common_arg_candidates.pop(last=False)
 
             com_args = arg_tracker.func_to_argset[i].intersection(
                     arg_tracker.func_to_argset[j])
@@ -262,21 +278,36 @@ def match_common_args(func_class, funcs, opt_subs):
                 # combined in a previous iteration.
                 continue
 
-            com_func = func_class(*arg_tracker.get_args_in_value_order(com_args))
-            com_func_number = arg_tracker.get_or_add_value_number(com_func)
-
             # For all sets, replace the common symbols by the function
             # over them, to allow recursive matches.
 
             diff_i = arg_tracker.func_to_argset[i].difference(com_args)
-            arg_tracker.update_func_argset(i, diff_i | set([com_func_number]))
-            changed.add(i)
+            if diff_i:
+                # com_func needs to be unevaluated to allow for recursive matches.
+                com_func = func_class(
+                        *arg_tracker.get_args_in_value_order(com_args),
+                        evaluate=False)
+                com_func_number = arg_tracker.get_or_add_value_number(com_func)
+                arg_tracker.update_func_argset(i, diff_i | set([com_func_number]))
+                changed.add(i)
+            else:
+                # Treat the whole expression as a CSE.
+                #
+                # The reason this needs to be done is somewhat subtle. Within
+                # tree_cse(), to_eliminate only contains expressions that are
+                # seen more than once. The problem is unevaluated expressions
+                # do not compare equal to the evaluated equivalent. So
+                # tree_cse() won't mark funcs[i] as a CSE if we use an
+                # unevaluated version.
+                com_func = funcs[i]
+                com_func_number = arg_tracker.get_or_add_value_number(funcs[i])
 
             diff_j = arg_tracker.func_to_argset[j].difference(com_args)
             arg_tracker.update_func_argset(j, diff_j | set([com_func_number]))
             changed.add(j)
 
-            for k in arg_tracker.gen_subset_candidates(com_args, j + 1):
+            for k in arg_tracker.get_subset_candidates(
+                    com_args, common_arg_candidates):
                 diff_k = arg_tracker.func_to_argset[k].difference(com_args)
                 arg_tracker.update_func_argset(k, diff_k | set([com_func_number]))
                 changed.add(k)
@@ -299,15 +330,15 @@ def opt_cse(exprs):
     """
     opt_subs = dict()
 
-    adds = []
-    muls = []
+    from sumpy.tools import OrderedSet
+    adds = OrderedSet()
+    muls = OrderedSet()
 
     seen_subexp = set()
 
     # {{{ look for optimization opportunities, clean up minus signs
 
     def find_opts(expr):
-
         if not isinstance(expr, Basic):
             return
 
@@ -335,10 +366,10 @@ def opt_cse(exprs):
                 expr = neg_expr
 
         if isinstance(expr, Mul):
-            muls.append(expr)
+            muls.add(expr)
 
         elif isinstance(expr, Add):
-            adds.append(expr)
+            adds.add(expr)
 
         elif isinstance(expr, Pow):
             base, exp = expr.args
@@ -351,8 +382,9 @@ def opt_cse(exprs):
         if isinstance(e, Basic):
             find_opts(e)
 
-    match_common_args(Add, adds, opt_subs)
-    match_common_args(Mul, muls, opt_subs)
+    match_common_args(Add, list(adds), opt_subs)
+    match_common_args(Mul, list(muls), opt_subs)
+
     return opt_subs
 
 # }}}
diff --git a/sumpy/tools.py b/sumpy/tools.py
index 5e0d4341c136c1594194e84e4478f613e576dbb5..60687510c66cd76aaca87af035464c858133d766 100644
--- a/sumpy/tools.py
+++ b/sumpy/tools.py
@@ -279,6 +279,84 @@ class KernelComputation(object):
 # }}}
 
 
+# {{{ OrderedSet
+
+# Source: http://code.activestate.com/recipes/576694-orderedset/
+# Author: Raymond Hettinger
+# License: MIT
+
+try:
+    from collections.abc import MutableSet
+except ImportError:  # Python 2
+    from collections import MutableSet
+
+
+class OrderedSet(MutableSet):
+    """A mutable set that iterates in insertion order.
+
+    Keys are kept in a doubly linked list (``[key, prev, next]`` nodes
+    anchored at the sentinel ``self.end``) and indexed by ``self.map``.
+    """
+
+    def __init__(self, iterable=None):
+        self.end = end = []
+        end += [None, end, end]         # sentinel node for doubly linked list
+        self.map = {}                   # key --> [key, prev, next]
+        if iterable is not None:
+            self |= iterable
+
+    def __len__(self):
+        return len(self.map)
+
+    def __contains__(self, key):
+        return key in self.map
+
+    def add(self, key):
+        if key not in self.map:
+            end = self.end
+            curr = end[1]
+            curr[2] = end[1] = self.map[key] = [key, curr, end]
+
+    def discard(self, key):
+        if key in self.map:
+            key, prev, next = self.map.pop(key)
+            prev[2] = next
+            next[1] = prev
+
+    def __iter__(self):
+        end = self.end
+        curr = end[2]
+        while curr is not end:
+            yield curr[0]
+            curr = curr[2]
+
+    def __reversed__(self):
+        end = self.end
+        curr = end[1]
+        while curr is not end:
+            yield curr[0]
+            curr = curr[1]
+
+    def pop(self, last=True):
+        if not self:
+            raise KeyError('set is empty')
+        key = self.end[1][0] if last else self.end[2][0]
+        self.discard(key)
+        return key
+
+    def __repr__(self):
+        if not self:
+            return '%s()' % (self.__class__.__name__,)
+        return '%s(%r)' % (self.__class__.__name__, list(self))
+
+    def __eq__(self, other):
+        if isinstance(other, OrderedSet):
+            return len(self) == len(other) and list(self) == list(other)
+        return set(self) == set(other)
+
+# }}}
+
+
 class KernelCacheWrapper(object):
     @memoize_method
     def get_cached_optimized_kernel(self, **kwargs):
diff --git a/test/test_codegen.py b/test/test_codegen.py
new file mode 100644
index 0000000000000000000000000000000000000000..59fe0340caceb7eaf281f7d725d0eb7343c82bcc
--- /dev/null
+++ b/test/test_codegen.py
@@ -0,0 +1,59 @@
+from __future__ import division, absolute_import, print_function
+
+__copyright__ = "Copyright (C) 2017 Matt Wala"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+import sys
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+def test_symbolic_assignment_name_uniqueness():
+    # https://gitlab.tiker.net/inducer/sumpy/issues/13
+    from sumpy.assignment_collection import SymbolicAssignmentCollection
+
+    sac = SymbolicAssignmentCollection({"s_0": 1})
+    sac.assign_unique("s_", 1)
+    sac.assign_unique("s_", 1)
+    assert len(sac.assignments) == 3
+
+    sac = SymbolicAssignmentCollection()
+    sac.assign_unique("s_0", 1)
+    sac.assign_unique("s_", 1)
+    sac.assign_unique("s_", 1)
+
+    assert len(sac.assignments) == 3
+
+
+# You can test individual routines by typing
+# $ python test_codegen.py 'test_symbolic_assignment_name_uniqueness()'
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        exec(sys.argv[1])
+    else:
+        from py.test.cmdline import main
+        main([__file__])
+
+# vim: fdm=marker
diff --git a/test/test_cse.py b/test/test_cse.py
index b5697ac4fa2bf67a0f5ed6f8eafa64f9b379059a..d2904176e486ee16bb92bae9c8702f76969a4bae 100644
--- a/test/test_cse.py
+++ b/test/test_cse.py
@@ -253,7 +253,7 @@ def test_issue_4499():
         sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S(1)/2, z/2, -b + 1, -2*a + b,
         -2*a))  # noqa
     c = cse(t)
-    assert len(c[0]) == 13
+    assert len(c[0]) == 11
 
 
 def test_issue_6169():
@@ -345,6 +345,21 @@ def test_issue_7840():
     assert len(substitutions) < 1
 
 
+def test_recursive_matching():
+    assert cse([x+y, 2+x+y, x+y+z, 3+x+y+z]) == \
+        ([(x0, x + y), (x1, x0 + z)], [x0, x0 + 2, x1, x1 + 3])
+    assert cse(reversed([x+y, 2+x+y, x+y+z, 3+x+y+z])) == \
+        ([(x0, x + y), (x1, x0 + z)], [x1 + 3, x1, x0 + 2, x0])
+    # sympy 1.0 gives ([(x0, x*y*z)], [5*x0, w*(x*y), 3*x0])
+    assert cse([x*y*z*5, x*y*w, x*y*z*3]) == \
+        ([(x0, x*y), (x1, x0*z)], [5*x1, w*x0, 3*x1])
+    # sympy 1.0 gives ([(x4, x*y*z)], [5*x4, w*x3*x4, 3*x*x0*x1*x2*y])
+    assert cse([x*y*z*5, x*y*z*w*x3, x*y*3*x0*x1*x2]) == \
+        ([(x4, x*y), (x5, x4*z)], [5*x5, w*x3*x5, 3*x0*x1*x2*x4])
+    assert cse([2*x*x, x*x*y, x*x*y*w, x*x*y*w*x0, x*x*y*w*x2]) == \
+        ([(x1, x**2), (x3, x1*y), (x4, w*x3)], [2*x1, x3, x4, x0*x4, x2*x4])
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])