Skip to content
Snippets Groups Projects
Commit bd350265 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Fix sum/product unification.

parent 8f0ae973
No related branches found
No related tags found
No related merge requests found
......@@ -101,41 +101,41 @@ class UnifierBase(RecursiveMapper):
return UnificationRecord([(lhs, rhs)])
def map_constant(self, expr, other, unis):
def map_constant(self, expr, other, urecs):
if expr == other:
return unis
return urecs
else:
return []
def map_variable(self, expr, other, unis):
def map_variable(self, expr, other, urecs):
new_uni_record = self.unification_record_from_equation(
expr, other)
if new_uni_record is None:
if (isinstance(other, Variable)
and other.name == expr.name
and expr.name not in self.lhs_mapping_candidates):
return unis
return urecs
else:
return []
else:
return unify_many(unis, new_uni_record)
return unify_many(urecs, new_uni_record)
def map_subscript(self, expr, other, unis):
def map_subscript(self, expr, other, urecs):
if not isinstance(other, type(expr)):
return self.treat_mismatch(expr, other, unis)
return self.treat_mismatch(expr, other, urecs)
return self.rec(expr.aggregate, other.aggregate,
self.rec(expr.index, other.index, unis))
self.rec(expr.index, other.index, urecs))
def map_lookup(self, expr, other, unis):
def map_lookup(self, expr, other, urecs):
if not isinstance(other, type(expr)):
return self.treat_mismatch(expr, other, unis)
return self.treat_mismatch(expr, other, urecs)
if expr.name != other.name:
return []
return self.rec(expr.aggregate, other.aggregate, unis)
return self.rec(expr.aggregate, other.aggregate, urecs)
def map_sum(self, expr, other, unis):
def map_sum(self, expr, other, urecs):
if (not isinstance(other, type(expr))
or len(expr.children) != len(other.children)):
return []
......@@ -145,7 +145,7 @@ class UnifierBase(RecursiveMapper):
from pytools import generate_permutations
had_structural_match = False
for perm in generate_permutations(range(len(expr.children))):
it_assignments = unis
it_assignments = urecs
for my_child, other_child in zip(
expr.children,
......@@ -159,52 +159,52 @@ class UnifierBase(RecursiveMapper):
result.extend(it_assignments)
if not had_structural_match:
return self.treat_mismatch(expr, other, unis)
return self.treat_mismatch(expr, other, urecs)
return result
map_product = map_sum
def map_negation(self, expr, other, unis):
def map_negation(self, expr, other, urecs):
if not isinstance(other, type(expr)):
return self.treat_mismatch(expr, other, unis)
return self.rec(expr.child, other.child, unis)
return self.treat_mismatch(expr, other, urecs)
return self.rec(expr.child, other.child, urecs)
def map_quotient(self, expr, other, unis):
def map_quotient(self, expr, other, urecs):
if not isinstance(other, type(expr)):
return self.treat_mismatch(expr, other, unis)
return self.treat_mismatch(expr, other, urecs)
return self.rec(expr.numerator, other.numerator,
self.rec(expr.denominator, other.denominator, unis))
self.rec(expr.denominator, other.denominator, urecs))
map_floor_div = map_quotient
map_remainder = map_quotient
def map_power(self, expr, other, unis):
def map_power(self, expr, other, urecs):
if not isinstance(other, type(expr)):
return self.treat_mismatch(expr, other, unis)
return self.treat_mismatch(expr, other, urecs)
return self.rec(expr.base, other.base,
self.rec(expr.exponent, other.exponent, unis))
self.rec(expr.exponent, other.exponent, urecs))
def map_list(self, expr, other, unis):
def map_list(self, expr, other, urecs):
if (not isinstance(other, type(expr))
or len(expr) != len(other)):
return []
for my_child, other_child in zip(expr, other):
unis = self.rec(my_child, other_child, unis)
if not unis:
urecs = self.rec(my_child, other_child, urecs)
if not urecs:
break
return unis
return urecs
map_tuple = map_list
def __call__(self, expr, other, unis=None):
if unis is None:
unis = [UnificationRecord([])]
return self.rec(expr, other, unis)
def __call__(self, expr, other, urecs=None):
if urecs is None:
urecs = [UnificationRecord([])]
return self.rec(expr, other, urecs)
......@@ -214,14 +214,69 @@ class UnidirectionalUnifier(UnifierBase):
subexpression of the second.
"""
def treat_mismatch(self, expr, other, unis):
def treat_mismatch(self, expr, other, urecs):
return []
def map_commut_assoc(self, expr, other, urecs, factory):
if not isinstance(other, type(expr)):
return
plain_cand_variables = []
non_var_children = []
for child in expr.children:
if (isinstance(child, Variable)
and child.name in self.lhs_mapping_candidates):
plain_cand_variables.append(child)
else:
non_var_children.append(child)
# list (with indices matching non_var_children) of
# list of tuples (other_index, unifiers)
unification_candidates = []
class BidirectionalUnifier(UnifierBase):
"""Only assigns variables encountered in the first expression to
subexpression of the second.
"""
for i, my_child in enumerate(non_var_children):
i_matches = []
for j, other_child in enumerate(other.children):
result = self.rec(my_child, other_child, urecs)
if result:
i_matches.append((j, result))
unification_candidates.append(i_matches)
def match_children(urec, next_my_idx, other_leftovers):
if next_my_idx >= len(non_var_children):
if not plain_cand_variables and other_leftovers:
return
eqns = []
for pv in plain_cand_variables:
eqns.append((pv, factory(
other.children[i] for i in other_leftovers)))
other_leftovers = []
yield urec.unify(UnificationRecord(eqns))
return
for other_idx, pair_urecs in unification_candidates[next_my_idx]:
if other_idx not in other_leftovers:
continue
new_urecs = unify_many(pair_urecs, urec)
new_rhs_leftovers = other_leftovers - set([other_idx])
for cand_urec in new_urecs:
for result_urec in match_children(cand_urec, next_my_idx+1,
new_rhs_leftovers):
yield result_urec
for urec in match_children(
UnificationRecord([]), 0, set(range(len(other.children)))):
yield urec
def map_sum(self, expr, other, unis):
from pymbolic.primitives import flattened_sum
return list(self.map_commut_assoc(expr, other, unis, flattened_sum))
treat_mismatch = UnifierBase.map_variable
def map_product(self, expr, other, unis):
from pymbolic.primitives import flattened_product
return list(self.map_commut_assoc(expr, other, unis, flattened_product))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment