Skip to content
Snippets Groups Projects
Commit 3b7bdf6c authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Unifier: Add force_var_match.

parent 748fb495
No related branches found
No related tags found
No related merge requests found
...@@ -73,21 +73,29 @@ def unify_many(unis1, uni2): ...@@ -73,21 +73,29 @@ def unify_many(unis1, uni2):
class UnifierBase(RecursiveMapper): class UnifierBase(RecursiveMapper):
def __init__(self, lhs_mapping_candidates=None, def __init__(self, lhs_mapping_candidates=None,
rhs_mapping_candidates=None): rhs_mapping_candidates=None,
force_var_match=True):
self.lhs_mapping_candidates = lhs_mapping_candidates self.lhs_mapping_candidates = lhs_mapping_candidates
self.rhs_mapping_candidates = rhs_mapping_candidates self.rhs_mapping_candidates = rhs_mapping_candidates
self.force_var_match = force_var_match
def unification_record_from_equation(self, lhs, rhs): def unification_record_from_equation(self, lhs, rhs):
if isinstance(lhs, (tuple, list)) or isinstance(rhs, (tuple, list)): if isinstance(lhs, (tuple, list)) or isinstance(rhs, (tuple, list)):
# must match elementwise! # must match elementwise!
return None return None
lhs_is_var = isinstance(lhs, Variable)
rhs_is_var = isinstance(rhs, Variable)
if self.force_var_match and not (lhs_is_var or rhs_is_var):
return None
if (self.lhs_mapping_candidates is not None if (self.lhs_mapping_candidates is not None
and isinstance(lhs, Variable) and lhs_is_var
and lhs.name not in self.lhs_mapping_candidates): and lhs.name not in self.lhs_mapping_candidates):
return None return None
if (self.rhs_mapping_candidates is not None if (self.rhs_mapping_candidates is not None
and isinstance(rhs, Variable) and rhs_is_var
and rhs.name not in self.rhs_mapping_candidates): and rhs.name not in self.rhs_mapping_candidates):
return None return None
...@@ -122,17 +130,46 @@ class UnifierBase(RecursiveMapper): ...@@ -122,17 +130,46 @@ class UnifierBase(RecursiveMapper):
def map_lookup(self, expr, other, unis): def map_lookup(self, expr, other, unis):
if not isinstance(other, type(expr)): if not isinstance(other, type(expr)):
return self.treat_mismatch(expr, other, unis) return self.treat_mismatch(expr, other, unis)
if self.name != other.name: if expr.name != other.name:
return [] return []
return self.rec(expr.aggregate, other.aggregate, unis) return self.rec(expr.aggregate, other.aggregate, unis)
def map_sum(self, expr, other, unis):
if (not isinstance(other, type(expr))
or len(expr.children) != len(other.children)):
return []
result = []
from pytools import generate_permutations
had_structural_match = False
for perm in generate_permutations(range(len(expr.children))):
it_assignments = unis
for my_child, other_child in zip(
expr.children,
(other.children[i] for i in perm)):
it_assignments = self.rec(my_child, other_child, it_assignments)
if not it_assignments:
break
if it_assignments:
had_structural_match = True
result.extend(it_assignments)
if not had_structural_match:
return self.treat_mismatch(expr, other, unis)
return result
map_product = map_sum
def map_negation(self, expr, other, unis): def map_negation(self, expr, other, unis):
if not isinstance(other, type(expr)): if not isinstance(other, type(expr)):
return self.treat_mismatch(expr, other, unis) return self.treat_mismatch(expr, other, unis)
return self.rec(expr.child, other.child, unis) return self.rec(expr.child, other.child, unis)
def map_quotient(self, expr, other, unis): def map_quotient(self, expr, other, unis):
if not isinstance(other, type(expr)): if not isinstance(other, type(expr)):
return self.treat_mismatch(expr, other, unis) return self.treat_mismatch(expr, other, unis)
...@@ -162,7 +199,6 @@ class UnifierBase(RecursiveMapper): ...@@ -162,7 +199,6 @@ class UnifierBase(RecursiveMapper):
return unis return unis
map_tuple = map_list map_tuple = map_list
def __call__(self, expr, other, unis=None): def __call__(self, expr, other, unis=None):
...@@ -178,36 +214,6 @@ class UnidirectionalUnifier(UnifierBase): ...@@ -178,36 +214,6 @@ class UnidirectionalUnifier(UnifierBase):
subexpression of the second. subexpression of the second.
""" """
def map_sum(self, expr, other, unis):
if (not isinstance(other, type(expr))
or len(expr.children) != len(other.children)):
return []
result = []
from pytools import generate_permutations
had_structural_match = False
for perm in generate_permutations(range(len(expr.children))):
it_assignments = unis
for my_child, other_child in zip(
expr.children,
(other.children[i] for i in perm)):
it_assignments = self.rec(my_child, other_child, it_assignments)
if not it_assignments:
break
if it_assignments:
had_structural_match = True
result.extend(it_assignments)
if not had_structural_match:
return self.treat_mismatch(expr, other, unis)
return result
map_product = map_sum
def treat_mismatch(self, expr, other, unis): def treat_mismatch(self, expr, other, unis):
return [] return []
...@@ -219,5 +225,3 @@ class BidirectionalUnifier(UnifierBase): ...@@ -219,5 +225,3 @@ class BidirectionalUnifier(UnifierBase):
""" """
treat_mismatch = UnifierBase.map_variable treat_mismatch = UnifierBase.map_variable
map_sum = UnifierBase.map_variable
map_product = UnifierBase.map_variable
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment