Skip to content
Snippets Groups Projects
Commit b5824207 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Fix inheritance order on IdentityMapper. Add test. (+)

Mapper: Variable is an AlgebraicLeaf.
Fix NonrecursiveIdentityMapper.
parent a7b736b9
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,9 @@ class Mapper(object):
else:
return self.map_foreign(expr, *args, **kwargs)
def map_variable(self, expr, *args, **kwargs):
return self.map_algebraic_leaf(self, expr, *args, **kwargs)
def map_subscript(self, expr, *args, **kwargs):
return self.map_algebraic_leaf(self, expr, *args, **kwargs)
......@@ -161,9 +164,10 @@ class IdentityMapperBase(object):
class IdentityMapper(RecursiveMapper, IdentityMapperBase):
class IdentityMapper(IdentityMapperBase, RecursiveMapper):
def handle_unsupported_expression(self, expr, *args, **kwargs):
return expr
class NonrecursiveIdentityMapper(Mapper, IdentityMapperBase):
pass
class NonrecursiveIdentityMapper(IdentityMapperBase, Mapper):
def handle_unsupported_expression(self, expr, *args, **kwargs):
return expr
......@@ -5,23 +5,23 @@ import pymbolic.mapper
class SubstitutionMapper(pymbolic.mapper.IdentityMapper):
def __init__(self, variable_assignments):
self.Assignments = variable_assignments
self.assignments = variable_assignments
def map_variable(self, expr):
try:
return self.Assignments[expr]
return self.assignments[expr]
except KeyError:
return expr
def map_subscript(self, expr):
try:
return self.Assignments[expr]
return self.assignments[expr]
except KeyError:
return pymbolic.mapper.IdentityMapper.map_subscript(self, expr)
def map_lookup(self, expr):
try:
return self.Assignments[expr]
return self.assignments[expr]
except KeyError:
return pymbolic.mapper.IdentityMapper.map_lookup(self, expr)
......
......@@ -9,8 +9,13 @@ class TestPymbolic(unittest.TestCase):
x = var("x")
u = (x+1)**5
print expand(u)
expand(u)
def test_substitute(self):
from pymbolic import parse, substitute, evaluate
u = parse("5+x.min**2")
xmin = parse("x.min")
assert evaluate(substitute(u, {xmin:25})) == 630
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment