diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py index b15900933b37d0461d7b4b39274b0c755cdfdacd..b5c14cc05fc55557cbb3478c6a09273a46ac3e77 100644 --- a/pymbolic/mapper/unifier.py +++ b/pymbolic/mapper/unifier.py @@ -72,16 +72,36 @@ def unify_many(unis1, uni2): class UnifierBase(RecursiveMapper): + # The idea of the algorithm here is that the unifier accumulates a set of + # unification possibilities (:class:`UnificationRecord`) as it descends the + # expression tree. :func:`unify_many` above then checks if these possibilities + # are consistent with new incoming information (also encoded as a + # :class:`UnificationRecord`) and either augments or abandons them. + def __init__(self, lhs_mapping_candidates=None, rhs_mapping_candidates=None, force_var_match=True): + """ + :arg lhs_mapping_candidates: list or set of variable names that may be + assigned in the left-hand ("first") expression + :arg rhs_mapping_candidates: list or set of variable names that may be + assigned in the right-hand ("second") expression + :arg force_var_match: In the (unimplemented) case of bidirectional + unification, only assign to variable names, don't make matches + between higher-level expressions. + """ + self.lhs_mapping_candidates = lhs_mapping_candidates self.rhs_mapping_candidates = rhs_mapping_candidates self.force_var_match = force_var_match def unification_record_from_equation(self, lhs, rhs): if isinstance(lhs, (tuple, list)) or isinstance(rhs, (tuple, list)): - # must match elementwise! + # Always force lists/tuples to agree elementwise, never + # generate a unification record between them directly. + # This pushes the matching process down to the elementwise + # level. + return None lhs_is_var = isinstance(lhs, Variable) @@ -110,7 +130,9 @@ class UnifierBase(RecursiveMapper): def map_variable(self, expr, other, urecs): new_uni_record = self.unification_record_from_equation( expr, other) + if new_uni_record is None: + # Check if the variables match literally--that's ok, too. if (isinstance(other, Variable) and other.name == expr.name and expr.name not in self.lhs_mapping_candidates): @@ -210,8 +232,9 @@ class UnifierBase(RecursiveMapper): class UnidirectionalUnifier(UnifierBase): - """Only assigns variables encountered in the first expression to - subexpression of the second. + """Finds assignments of variables encountered in the + first ("left") expression to subexpression of the second + ("right") expression. """ def treat_mismatch(self, expr, other, urecs):