diff --git a/pymbolic/algorithm.py b/pymbolic/algorithm.py index 207df9d16751f491b8015c5e62e0649bcc4ab3b4..95d1b9673b8d88b6db5bb5900b91a7dbfa332847 100644 --- a/pymbolic/algorithm.py +++ b/pymbolic/algorithm.py @@ -27,6 +27,17 @@ def integer_power(x, n, one=1): def gcd(q, r): return extended_euclidean(q, r)[0] +def gcd_many(*args): + if len(args) == 0: + return 1 + elif len(args) == 1: + return args[0] + else: + return reduce(gcd, args) + +def lcm(q, r): + return abs(q*r)//gcd(q, r) + diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py index e13e56382b7be5c64437354c60a73d8f13fa97a3..b15900933b37d0461d7b4b39274b0c755cdfdacd 100644 --- a/pymbolic/mapper/unifier.py +++ b/pymbolic/mapper/unifier.py @@ -101,41 +101,41 @@ class UnifierBase(RecursiveMapper): return UnificationRecord([(lhs, rhs)]) - def map_constant(self, expr, other, unis): + def map_constant(self, expr, other, urecs): if expr == other: - return unis + return urecs else: return [] - def map_variable(self, expr, other, unis): + def map_variable(self, expr, other, urecs): new_uni_record = self.unification_record_from_equation( expr, other) if new_uni_record is None: if (isinstance(other, Variable) and other.name == expr.name and expr.name not in self.lhs_mapping_candidates): - return unis + return urecs else: return [] else: - return unify_many(unis, new_uni_record) + return unify_many(urecs, new_uni_record) - def map_subscript(self, expr, other, unis): + def map_subscript(self, expr, other, urecs): if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) + return self.treat_mismatch(expr, other, urecs) return self.rec(expr.aggregate, other.aggregate, - self.rec(expr.index, other.index, unis)) + self.rec(expr.index, other.index, urecs)) - def map_lookup(self, expr, other, unis): + def map_lookup(self, expr, other, urecs): if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) + return self.treat_mismatch(expr, other, urecs) if expr.name != other.name: return [] - return self.rec(expr.aggregate, other.aggregate, unis) + return self.rec(expr.aggregate, other.aggregate, urecs) - def map_sum(self, expr, other, unis): + def map_sum(self, expr, other, urecs): if (not isinstance(other, type(expr)) or len(expr.children) != len(other.children)): return [] @@ -145,7 +145,7 @@ class UnifierBase(RecursiveMapper): from pytools import generate_permutations had_structural_match = False for perm in generate_permutations(range(len(expr.children))): - it_assignments = unis + it_assignments = urecs for my_child, other_child in zip( expr.children, @@ -159,52 +159,52 @@ class UnifierBase(RecursiveMapper): result.extend(it_assignments) if not had_structural_match: - return self.treat_mismatch(expr, other, unis) + return self.treat_mismatch(expr, other, urecs) return result map_product = map_sum - def map_negation(self, expr, other, unis): + def map_negation(self, expr, other, urecs): if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) - return self.rec(expr.child, other.child, unis) + return self.treat_mismatch(expr, other, urecs) + return self.rec(expr.child, other.child, urecs) - def map_quotient(self, expr, other, unis): + def map_quotient(self, expr, other, urecs): if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) + return self.treat_mismatch(expr, other, urecs) return self.rec(expr.numerator, other.numerator, - self.rec(expr.denominator, other.denominator, unis)) + self.rec(expr.denominator, other.denominator, urecs)) map_floor_div = map_quotient map_remainder = map_quotient - def map_power(self, expr, other, unis): + def map_power(self, expr, other, urecs): if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) + return self.treat_mismatch(expr, other, urecs) return self.rec(expr.base, other.base, - self.rec(expr.exponent, other.exponent, unis)) + self.rec(expr.exponent, other.exponent, urecs)) - def map_list(self, expr, other, unis): + def map_list(self, expr, other, urecs): if (not isinstance(other, type(expr)) or len(expr) != len(other)): return [] for my_child, other_child in zip(expr, other): - unis = self.rec(my_child, other_child, unis) - if not unis: + urecs = self.rec(my_child, other_child, urecs) + if not urecs: break - return unis + return urecs map_tuple = map_list - def __call__(self, expr, other, unis=None): - if unis is None: - unis = [UnificationRecord([])] - return self.rec(expr, other, unis) + def __call__(self, expr, other, urecs=None): + if urecs is None: + urecs = [UnificationRecord([])] + return self.rec(expr, other, urecs) @@ -214,14 +214,69 @@ class UnidirectionalUnifier(UnifierBase): subexpression of the second. """ - def treat_mismatch(self, expr, other, unis): + def treat_mismatch(self, expr, other, urecs): return [] + def map_commut_assoc(self, expr, other, urecs, factory): + if not isinstance(other, type(expr)): + return + + plain_cand_variables = [] + non_var_children = [] + for child in expr.children: + if (isinstance(child, Variable) + and child.name in self.lhs_mapping_candidates): + plain_cand_variables.append(child) + else: + non_var_children.append(child) + # list (with indices matching non_var_children) of + # list of tuples (other_index, unifiers) + unification_candidates = [] -class BidirectionalUnifier(UnifierBase): - """Only assigns variables encountered in the first expression to - subexpression of the second. - """ + for i, my_child in enumerate(non_var_children): + i_matches = [] + for j, other_child in enumerate(other.children): + result = self.rec(my_child, other_child, urecs) + if result: + i_matches.append((j, result)) + + unification_candidates.append(i_matches) + + def match_children(urec, next_my_idx, other_leftovers): + if next_my_idx >= len(non_var_children): + if not plain_cand_variables and other_leftovers: + return + + eqns = [] + for pv in plain_cand_variables: + eqns.append((pv, factory( + other.children[i] for i in other_leftovers))) + other_leftovers = [] + + yield urec.unify(UnificationRecord(eqns)) + return + + for other_idx, pair_urecs in unification_candidates[next_my_idx]: + if other_idx not in other_leftovers: + continue + + new_urecs = unify_many(pair_urecs, urec) + new_rhs_leftovers = other_leftovers - set([other_idx]) + + for cand_urec in new_urecs: + for result_urec in match_children(cand_urec, next_my_idx+1, + new_rhs_leftovers): + yield result_urec + + for urec in match_children( + UnificationRecord([]), 0, set(range(len(other.children)))): + yield urec + + def map_sum(self, expr, other, unis): + from pymbolic.primitives import flattened_sum + return list(self.map_commut_assoc(expr, other, unis, flattened_sum)) - treat_mismatch = UnifierBase.map_variable + def map_product(self, expr, other, unis): + from pymbolic.primitives import flattened_product + return list(self.map_commut_assoc(expr, other, unis, flattened_product)) diff --git a/pymbolic/parser.py b/pymbolic/parser.py index 9a067df1d9072ed35ac87c0c0ffe9f5e6cfc9d49..bfc4f553fd7590feaf19d2b583454d2a4fe2c935 100644 --- a/pymbolic/parser.py +++ b/pymbolic/parser.py @@ -66,7 +66,10 @@ def parse(expr_str): def parse_expression(pstate, min_precedence=0): pstate.expect_not_end() - if pstate.is_next(_minus): + if pstate.is_next(_times): + pstate.advance() + left_exp = primitives.Wildcard() + elif pstate.is_next(_minus): pstate.advance() left_exp = -parse_expression(pstate, _PREC_UNARY_MINUS) elif pstate.is_next(_openpar): diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 889611efd41642c1f22c5a3d3abfa26901a55b5a..d658d6b0d78530c979257bbb7aefc905bc0127f0 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -259,6 +259,17 @@ class Variable(Leaf): +class Wildcard(Leaf): + def is_equal(self, other): + return (other.__class__ == self.__class__ + and self.name == other.name) + + def get_hash(self): + return hash((self.__class__, self.name)) + + mapper_method = intern("map_wildcard") + + class FunctionSymbol(AlgebraicLeaf): """Represents the name of a function.