From 8621651e46788250a458beb13301cc936d948330 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Sat, 28 Jan 2017 17:06:10 -0600 Subject: [PATCH] Move kill_trivial_assignments() over to pymbolic. When doing substitution, SymEngine seems to collapse Subs nodes. This was leading to problems for sumpy, which relies on having Subs nodes in the expressions after killing trivial assignments. This change moves the job to pymbolic, which appears to preserve Subs nodes. --- sumpy/codegen.py | 143 +++++++++++++++++++++++++++++++++++++++---- sumpy/e2e.py | 9 +-- sumpy/e2p.py | 11 ++-- sumpy/p2e.py | 9 +-- sumpy/p2p.py | 1 + sumpy/qbx.py | 10 +-- sumpy/symbolic.py | 59 +----------------- sumpy/tools.py | 1 + test/test_codegen.py | 22 ++++--- 9 files changed, 161 insertions(+), 104 deletions(-) diff --git a/sumpy/codegen.py b/sumpy/codegen.py index 59cc9071..32492edc 100644 --- a/sumpy/codegen.py +++ b/sumpy/codegen.py @@ -29,6 +29,7 @@ import pyopencl as cl import pyopencl.tools # noqa import loopy as lp +import six import re from pymbolic.mapper import IdentityMapper, WalkMapper, CSECachingMapperMixin @@ -38,8 +39,7 @@ from loopy.types import NumpyType from pytools import memoize_method -from sumpy.symbolic import (USE_SYMENGINE, - SympyToPymbolicMapper as SympyToPymbolicMapperBase) +from sumpy.symbolic import (SympyToPymbolicMapper as SympyToPymbolicMapperBase) import logging logger = logging.getLogger(__name__) @@ -64,7 +64,8 @@ _SPECIAL_FUNCTION_NAMES = frozenset(dir(sym.functions)) class SympyToPymbolicMapper(SympyToPymbolicMapperBase): def __init__(self, assignments): - self.assignments = assignments + self.assignments = dict( + (sym.Symbol(name), value) for name, value in assignments) self.derivative_cse_names = set() def map_Derivative(self, expr): # noqa @@ -85,13 +86,7 @@ class SympyToPymbolicMapper(SympyToPymbolicMapperBase): if isinstance(expr, int): return expr elif getattr(expr, "is_Function", False): - if USE_SYMENGINE: - try: - func_name = expr.get_name() - except AttributeError: - func_name = type(expr).__name__ - else: - func_name = type(expr).__name__ + func_name = SympyToPymbolicMapperBase.function_name(self, expr) # SymEngine capitalizes the names of the special functions. if func_name.lower() in _SPECIAL_FUNCTION_NAMES: func_name = func_name.lower() @@ -103,6 +98,130 @@ class SympyToPymbolicMapper(SympyToPymbolicMapperBase): # }}} +# {{{ trivial assignment elimination + +def make_one_step_subst(assignments): + assignments = dict(assignments) + unwanted_vars = set(six.iterkeys(assignments)) + + # Ensure no re-assignments. + assert len(unwanted_vars) == len(assignments) + + from loopy.symbolic import get_dependencies + unwanted_deps = dict( + (name, get_dependencies(value) & unwanted_vars) + for name, value in six.iteritems(assignments)) + + # {{{ compute substitution order + + toposort = [] + visited = set() + visiting = set() + + while unwanted_vars: + stack = [unwanted_vars.pop()] + + while stack: + top = stack[-1] + + if top in visiting: + visiting.remove(top) + toposort.append(top) + + if top in visited: + stack.pop() + continue + + visited.add(top) + visiting.add(top) + + for dep in unwanted_deps[top]: + # Check for no cycles. + assert dep not in visiting + stack.append(dep) + + # }}} + + # {{{ make substitution + + from pymbolic import substitute + + result = {} + + for name in toposort: + value = assignments[name] + value = substitute(value, result) + + result[name] = value + + # }}} + + # {{{ simplify substitution + + used_names = set(dep + for value in six.itervalues(result) + for dep in get_dependencies(value)) + + used_name_to_var = dict( + (used_name, prim.Variable(used_name)) for used_name in used_names) + + from pymbolic import evaluate + from functools import partial + simplify = partial(evaluate, context=used_name_to_var) + + for name, value in six.iteritems(result): + result[name] = simplify(value) + + # }}} + + return result + + +def is_assignment_nontrivial(name, value): + if prim.is_constant(value): + return False + elif isinstance(value, prim.Variable): + return False + elif (isinstance(value, prim.Product) + and len(value.children) == 2 + and sum(1 for arg in value.children if prim.is_constant(arg)) == 1 + and sum(1 for arg in value.children + if isinstance(arg, prim.Variable)) == 1): + # const*var: not good enough + return False + + return True + + +def kill_trivial_assignments(assignments, retain_names=set()): + logger.info("kill trivial assignments (plain): start") + approved_assignments = [] + rejected_assignments = [] + + for name, value in assignments: + if name in retain_names or is_assignment_nontrivial(name, value): + approved_assignments.append((name, value)) + else: + rejected_assignments.append((name, value)) + + # un-substitute rejected assignments + unsubst_rej = make_one_step_subst(rejected_assignments) + + result = [] + from pymbolic import substitute + for name, expr in approved_assignments: + r = substitute(expr, unsubst_rej) + result.append((name, r)) + + logger.info( + "kill trivial assignments (plain): done, {nrej} assignments killed" + .format(nrej=len(rejected_assignments))) + + return result + +# }}} + + # {{{ bessel handling BESSEL_PREAMBLE = """//CL// @@ -560,7 +679,7 @@ class MathConstantRewriter(CSECachingMapperMixin, IdentityMapper): def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[], - complex_dtype=None): + complex_dtype=None, retain_names=set()): logger.info("loopy instruction generation: start") assignments = list(assignments) @@ -572,6 +691,8 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[], if name not in sympy_conv.derivative_cse_names ] + assignments = kill_trivial_assignments(assignments, retain_names) + bdr = BesselDerivativeReplacer() assignments = [(name, bdr(expr)) for name, expr in assignments] diff --git a/sumpy/e2e.py b/sumpy/e2e.py index f93a8682..b8091140 100644 --- a/sumpy/e2e.py +++ b/sumpy/e2e.py @@ -93,17 +93,12 @@ class E2EBase(KernelCacheWrapper): sac.run_global_cse() - from sumpy.symbolic import kill_trivial_assignments - assignments = kill_trivial_assignments([ - (name, expr) - for name, expr in six.iteritems(sac.assignments)], - retain_names=tgt_coeff_names) - from sumpy.codegen import to_loopy_insns return to_loopy_insns( - assignments, + six.iteritems(sac.assignments), vector_names=set(["d"]), pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()], + retain_names=tgt_coeff_names, complex_dtype=np.complex128 # FIXME ) diff --git a/sumpy/e2p.py b/sumpy/e2p.py index 118f1743..6a77a5cc 100644 --- a/sumpy/e2p.py +++ b/sumpy/e2p.py @@ -83,6 +83,7 @@ class E2PBase(KernelCacheWrapper): coeff_exprs = [sym.Symbol("coeff%d" % i) for i in range(len(self.expansion.get_coefficient_identifiers()))] value = self.expansion.evaluate(coeff_exprs, bvec) + result_names = [ sac.assign_unique("result_%d_p" % i, knl.postprocess_at_target(value, bvec)) @@ -91,16 +92,12 @@ class E2PBase(KernelCacheWrapper): sac.run_global_cse() - from sumpy.symbolic import kill_trivial_assignments - assignments = kill_trivial_assignments([ - (name, expr) - for name, expr in six.iteritems(sac.assignments)], - retain_names=result_names) - from sumpy.codegen import to_loopy_insns - loopy_insns = to_loopy_insns(assignments, + loopy_insns = to_loopy_insns( + six.iteritems(sac.assignments), vector_names=set(["b"]), pymbolic_expr_maps=[self.expansion.get_code_transformer()], + retain_names=result_names, complex_dtype=np.complex128 # FIXME ) diff --git a/sumpy/p2e.py b/sumpy/p2e.py index 4f687ac5..e85213d5 100644 --- a/sumpy/p2e.py +++ b/sumpy/p2e.py @@ -83,17 +83,12 @@ class P2EBase(KernelCacheWrapper): sac.run_global_cse() - from sumpy.symbolic import kill_trivial_assignments - assignments = kill_trivial_assignments([ - (name, expr) - for name, expr in six.iteritems(sac.assignments)], - retain_names=coeff_names) - from sumpy.codegen import to_loopy_insns return to_loopy_insns( - assignments, + six.iteritems(sac.assignments), vector_names=set(["a"]), pymbolic_expr_maps=[self.expansion.get_code_transformer()], + retain_names=coeff_names, complex_dtype=np.complex128 # FIXME ) diff --git a/sumpy/p2p.py b/sumpy/p2p.py index a8b92b91..39a90205 100644 --- a/sumpy/p2p.py +++ b/sumpy/p2p.py @@ -90,6 +90,7 @@ class P2PBase(KernelComputation, KernelCacheWrapper): vector_names=set(["d"]), pymbolic_expr_maps=[ knl.get_code_transformer() for knl in self.kernels], + retain_names=result_names, complex_dtype=np.complex128 # FIXME ) diff --git a/sumpy/qbx.py b/sumpy/qbx.py index 5012e1a6..8892739e 100644 --- a/sumpy/qbx.py +++ b/sumpy/qbx.py @@ -134,17 +134,13 @@ class LayerPotentialBase(KernelComputation, KernelCacheWrapper): sac.run_global_cse() - from sumpy.symbolic import kill_trivial_assignments - assignments = kill_trivial_assignments([ - (name, expr.subs("tau", 0)) - for name, expr in six.iteritems(sac.assignments)], - retain_names=result_names) - from sumpy.codegen import to_loopy_insns - loopy_insns = to_loopy_insns(assignments, + loopy_insns = to_loopy_insns( + six.iteritems(sac.assignments), vector_names=set(["a", "b"]), pymbolic_expr_maps=[ expn.kernel.get_code_transformer() for expn in self.expansions], + retain_names=result_names, complex_dtype=np.complex128 # FIXME ) diff --git a/sumpy/symbolic.py b/sumpy/symbolic.py index 565fa8c5..0fcda1fb 100644 --- a/sumpy/symbolic.py +++ b/sumpy/symbolic.py @@ -76,7 +76,7 @@ _find_symbolic_backend() # Before adding a function here, make sure it's present in both modules. SYMBOLIC_API = """ Add Basic Mul Pow exp sqrt symbols sympify cos sin atan2 Function Symbol -Integer Matrix Subs I pi functions""".split() +Derivative Integer Matrix Subs I pi functions""".split() if USE_SYMENGINE: from symengine import sympy_compat as sym @@ -101,63 +101,6 @@ def _coeff_isneg(a): return False -# {{{ trivial assignment elimination - -def make_one_step_subst(assignments): - unwanted_vars = set(sym.Symbol(name) for name, value in assignments) - - result = {} - assignments = dict((sym.Symbol(name), value) for name, value in assignments) - for name, value in assignments.items(): - while value.atoms(sym.Symbol) & unwanted_vars: - value = value.subs(assignments) - - result[name] = value - - return result - - -def is_assignment_nontrivial(name, value): - if value.is_Number: - return False - elif isinstance(value, sym.Symbol): - return False - elif (isinstance(value, sym.Mul) - and len(value.args) == 2 - and sum(1 for arg in value.args if arg.is_Number) == 1 - and sum(1 for arg in value.args if isinstance(arg, sym.Symbol)) == 1): - # const*var: not good enough - return False - - return True - - -def kill_trivial_assignments(assignments, retain_names=set()): - logger.info("kill trivial assignments (plain): start") - approved_assignments = [] - rejected_assignments = [] - - for name, value in assignments: - if name in retain_names or is_assignment_nontrivial(name, value): - approved_assignments.append((name, value)) - else: - rejected_assignments.append((name, value)) - - # un-substitute rejected assignments - unsubst_rej = make_one_step_subst(rejected_assignments) - - result = [] - for name, expr in approved_assignments: - r = expr.xreplace(unsubst_rej) - result.append((name, r)) - - logger.info("kill trivial assignments (plain): done") - - return result - -# }}} - - # {{{ debugging of sympy CSE via Maxima class _DerivativeKiller(IdentityMapperBase): diff --git a/sumpy/tools.py b/sumpy/tools.py index 7d0738a7..4c952e06 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -74,6 +74,7 @@ class MiDerivativeTaker(object): for next_deriv, next_mi in self.get_derivative_taking_sequence( current_mi, mi): expr = expr.diff(next_deriv) + print("taking diff", expr) self.cache_by_mi[next_mi] = expr return expr diff --git a/test/test_codegen.py b/test/test_codegen.py index fcb1dd58..376a8fcf 100644 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -30,27 +30,35 @@ logger = logging.getLogger(__name__) def test_kill_trivial_assignments(): - from sumpy.symbolic import kill_trivial_assignments, symbols, sympify - x, y, nt = symbols("x, y, nt") - t0, t1, t2 = symbols("t0:3") - u0, u1, u2 = symbols("u0:3") + from pymbolic import var + x, y, t0, t1, t2 = [var(s) for s in "x y t0 t1 t2".split()] assignments = ( - ("t0", sympify(6)), + ("t0", 6), ("t1", -t0), ("t2", 6*x), ("nt", x**y), # users of trivial assignments ("u0", t0 + 1), ("u1", t1 + 1), - ("u2", t2 + 1) + ("u2", t2 + 1), ) + from sumpy.codegen import kill_trivial_assignments result = kill_trivial_assignments( assignments, retain_names=("u0", "u1", "u2")) - assert result == [('nt', x**y), ('u0', 7), ('u1', -5), ('u2', 1 + 6*x)] + from pymbolic.primitives import Sum + + def _s(*vals): + return Sum(vals) + + assert result == [ + ('nt', x**y), + ('u0', _s(6, 1)), + ('u1', _s(-6, 1)), + ('u2', _s(6*x, 1))] # You can test individual routines by typing -- GitLab