Skip to content
Snippets Groups Projects
Commit b370c2e4 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Support tuples in expressions in a few places, to be able to compile tuple generation.

parent 1b5bc881
No related branches found
No related tags found
No related merge requests found
...@@ -24,7 +24,7 @@ def _constant_mapper(c): ...@@ -24,7 +24,7 @@ def _constant_mapper(c):
class CompileMapper(StringifyMapper): class CompileMapper(StringifyMapper):
def __init__(self): def __init__(self):
StringifyMapper.__init__(self, StringifyMapper.__init__(self,
constant_mapper=_constant_mapper) constant_mapper=_constant_mapper)
def map_polynomial(self, expr, enclosing_prec): def map_polynomial(self, expr, enclosing_prec):
...@@ -46,7 +46,7 @@ class CompileMapper(StringifyMapper): ...@@ -46,7 +46,7 @@ class CompileMapper(StringifyMapper):
next_exp = rev_data[i+1][0] next_exp = rev_data[i+1][0]
else: else:
next_exp = 0 next_exp = 0
result = "(%s+%s)%s" % (result, self(coeff, PREC_SUM), result = "(%s+%s)%s" % (result, self(coeff, PREC_SUM),
stringify_exp(exp-next_exp)) stringify_exp(exp-next_exp))
#print "A", result #print "A", result
#print "B", expr #print "B", expr
...@@ -68,16 +68,23 @@ class CompileMapper(StringifyMapper): ...@@ -68,16 +68,23 @@ class CompileMapper(StringifyMapper):
return "numpy.array(%s)" % stringify_leading_dimension(expr) return "numpy.array(%s)" % stringify_leading_dimension(expr)
def map_foreign(self, expr, enclosing_prec):
if isinstance(expr, tuple):
return "(%s)" % (", ".join(self.rec(child, PREC_NONE) for child in expr))
else:
return StringifyMapper.map_foreign(self, expr, enclosing_prec)
class CompiledExpression: class CompiledExpression:
"""This class encapsulates a compiled expression. """This class encapsulates a compiled expression.
The main reason for its existence is the fact that a dynamically-constructed The main reason for its existence is the fact that a dynamically-constructed
lambda function is not picklable. lambda function is not picklable.
""" """
def __init__(self, expression, variables = []): def __init__(self, expression, variables = []):
import pymbolic.primitives as primi import pymbolic.primitives as primi
...@@ -105,13 +112,13 @@ class CompiledExpression: ...@@ -105,13 +112,13 @@ class CompiledExpression:
all_variables = self._Variables + used_variables all_variables = self._Variables + used_variables
expr_s = CompileMapper()(self._Expression, PREC_NONE) expr_s = CompileMapper()(self._Expression, PREC_NONE)
func_s = "lambda %s:%s" % (",".join(str(v) for v in all_variables), func_s = "lambda %s: %s" % (",".join(str(v) for v in all_variables),
expr_s) expr_s)
self.__call__ = eval(func_s, ctx) self.__call__ = eval(func_s, ctx)
def __getinitargs__(self): def __getinitargs__(self):
return self._Expression, self._Variables return self._Expression, self._Variables
def __getstate__(self): def __getstate__(self):
return None return None
......
...@@ -142,6 +142,11 @@ class CombineMapper(RecursiveMapper): ...@@ -142,6 +142,11 @@ class CombineMapper(RecursiveMapper):
self.rec(expr.then), self.rec(expr.then),
self.rec(expr.else_)]) self.rec(expr.else_)])
def map_foreign(self, expr, *args):
if isinstance(expr, tuple):
return self.combine([self.rec(child) for child in expr])
else:
return RecursiveMapper.map_foreign(self, expr, *args)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment