diff --git a/src/mapper/__init__.py b/src/mapper/__init__.py index f0fd178972a5b2d510dbc7aa1e6b1869be2823f8..efbca3d94bfab59d0df99d77e52fa409cbc0dc1b 100644 --- a/src/mapper/__init__.py +++ b/src/mapper/__init__.py @@ -18,6 +18,9 @@ class Mapper(object): else: return self.map_foreign(expr, *args, **kwargs) + def map_variable(self, expr, *args, **kwargs): + return self.map_algebraic_leaf(self, expr, *args, **kwargs) + def map_subscript(self, expr, *args, **kwargs): return self.map_algebraic_leaf(self, expr, *args, **kwargs) @@ -161,9 +164,10 @@ class IdentityMapperBase(object): -class IdentityMapper(RecursiveMapper, IdentityMapperBase): +class IdentityMapper(IdentityMapperBase, RecursiveMapper): def handle_unsupported_expression(self, expr, *args, **kwargs): return expr -class NonrecursiveIdentityMapper(Mapper, IdentityMapperBase): - pass +class NonrecursiveIdentityMapper(IdentityMapperBase, Mapper): + def handle_unsupported_expression(self, expr, *args, **kwargs): + return expr diff --git a/src/mapper/substitutor.py b/src/mapper/substitutor.py index 3329da3a58efd622b60d5910f6009447716d8e99..2aaffc8756faddfbf293f7b3af3ac523a2908550 100644 --- a/src/mapper/substitutor.py +++ b/src/mapper/substitutor.py @@ -5,23 +5,23 @@ import pymbolic.mapper class SubstitutionMapper(pymbolic.mapper.IdentityMapper): def __init__(self, variable_assignments): - self.Assignments = variable_assignments + self.assignments = variable_assignments def map_variable(self, expr): try: - return self.Assignments[expr] + return self.assignments[expr] except KeyError: return expr def map_subscript(self, expr): try: - return self.Assignments[expr] + return self.assignments[expr] except KeyError: return pymbolic.mapper.IdentityMapper.map_subscript(self, expr) def map_lookup(self, expr): try: - return self.Assignments[expr] + return self.assignments[expr] except KeyError: return pymbolic.mapper.IdentityMapper.map_lookup(self, expr) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 4b14c56e3012f0dcf471ec3acc88ab2f88479335..5c0aaf33407d21f100785290a261075ca608283c 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -9,8 +9,13 @@ class TestPymbolic(unittest.TestCase): x = var("x") u = (x+1)**5 - print expand(u) + expand(u) + def test_substitute(self): + from pymbolic import parse, substitute, evaluate + u = parse("5+x.min**2") + xmin = parse("x.min") + assert evaluate(substitute(u, {xmin:25})) == 630 if __name__ == '__main__': unittest.main()