diff --git a/src/mapper/substitutor.py b/src/mapper/substitutor.py index ceed26718c37b15a8a54ddc2497a33649f2cdcd3..0b9108ada961a9b97789c69c1eb2a60dc44d6432 100644 --- a/src/mapper/substitutor.py +++ b/src/mapper/substitutor.py @@ -3,44 +3,50 @@ import pymbolic.mapper -class SubstitutionMapperBase(object): - def __init__(self, variable_assignments): - self.assignments = variable_assignments +class SubstitutionMapper(pymbolic.mapper.IdentityMapper): + def __init__(self, subst_func): + self.subst_func = subst_func def map_variable(self, expr): - try: - return self.assignments[expr] - except KeyError: + result = self.subst_func(expr) + if result is not None: + return result + else: return expr def map_subscript(self, expr): - try: - return self.assignments[expr] - except KeyError: + result = self.subst_func(expr) + if result is not None: + return result + else: return pymbolic.mapper.IdentityMapper.map_subscript(self, expr) def map_lookup(self, expr): - try: - return self.assignments[expr] - except KeyError: + result = self.subst_func(expr) + if result is not None: + return result + else: return pymbolic.mapper.IdentityMapper.map_lookup(self, expr) -class SubstitutionMapper(SubstitutionMapperBase, - pymbolic.mapper.IdentityMapper): - pass - - - - -def substitute(expression, variable_assignments = {}, - mapper_class=SubstitutionMapper): +def substitute(expression, variable_assignments={}, **kwargs): import pymbolic.primitives as primitives - new_var_ass = {} - for k, v in variable_assignments.iteritems(): - new_var_ass[primitives.make_variable(k)] = v + variable_assignments = variable_assignments.copy() + variable_assignments.update(kwargs) - return mapper_class(new_var_ass)(expression) + def subst_func(var): + try: + return variable_assignments[var] + except KeyError: + if isinstance(var, primitives.Variable): + try: + return variable_assignments[var.name] + except KeyError: + return None + else: + return None + + return SubstitutionMapper(subst_func)(expression)