diff --git a/src/mapper/substitutor.py b/src/mapper/substitutor.py index 2aaffc8756faddfbf293f7b3af3ac523a2908550..ceed26718c37b15a8a54ddc2497a33649f2cdcd3 100644 --- a/src/mapper/substitutor.py +++ b/src/mapper/substitutor.py @@ -3,7 +3,7 @@ import pymbolic.mapper -class SubstitutionMapper(pymbolic.mapper.IdentityMapper): +class SubstitutionMapperBase(object): def __init__(self, variable_assignments): self.assignments = variable_assignments @@ -28,11 +28,19 @@ class SubstitutionMapper(pymbolic.mapper.IdentityMapper): -def substitute(expression, variable_assignments = {}): +class SubstitutionMapper(SubstitutionMapperBase, + pymbolic.mapper.IdentityMapper): + pass + + + + +def substitute(expression, variable_assignments = {}, + mapper_class=SubstitutionMapper): import pymbolic.primitives as primitives new_var_ass = {} for k, v in variable_assignments.iteritems(): new_var_ass[primitives.make_variable(k)] = v - return SubstitutionMapper(new_var_ass)(expression) + return mapper_class(new_var_ass)(expression)