Skip to content
Snippets Groups Projects
Commit 60aae072 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Remove unnecessary property functions. Incremental improvements.

Remove is_constant from differentiator.
Trivial fixes.
parent b24dac1f
No related branches found
No related tags found
No related merge requests found
......@@ -24,13 +24,13 @@ class DependencyMapper(CombineMapper):
def map_variable(self, expr):
return set([expr])
def map_call(self, expr, *args, **kwargs):
def map_call(self, expr):
if self.IncludeCalls:
return set([expr])
else:
return CombineMapper.map_call(self, expr)
def map_lookup(self, expr, *args, **kwargs):
def map_lookup(self, expr):
if self.IncludeLookups:
return set([expr])
else:
......
......@@ -34,25 +34,25 @@ def map_math_functions_by_name(i, func, pars):
class DifferentiationMapper(pymbolic.mapper.RecursiveMapper):
def __init__(self, variable, func_map):
self.Variable = variable
self.FunctionMap = func_map
def __init__(self, variable, func_map=map_math_functions_by_name):
self.variable = variable
self.function_map = func_map
def map_constant(self, expr):
return 0
def map_variable(self, expr):
if expr == self.Variable:
if expr == self.variable:
return 1
else:
return 0
def map_call(self, expr):
return pymbolic.sum(
self.FunctionMap(i, expr.function, expr.parameters)
return pymbolic.flattened_sum(
self.function_map(i, expr.function, expr.parameters)
* self.rec(par)
for i, par in enumerate(expr.parameters)
if not self._isc(par))
)
map_subscript = map_variable
......@@ -60,29 +60,27 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper):
return -self.rec(expr.child)
def map_sum(self, expr):
return pymbolic.sum(self.rec(child) for child in expr.children
if not self._isc(child))
return pymbolic.flattened_sum(self.rec(child) for child in expr.children)
def map_product(self, expr):
return pymbolic.sum(
pymbolic.product(
return pymbolic.flattened_sum(
pymbolic.flattened_product(
expr.children[0:i] +
(self.rec(child),) +
expr.children[i+1:])
for i, child in enumerate(expr.children)
if not self._isc(child))
for i, child in enumerate(expr.children))
def map_quotient(self, expr):
f = expr.numerator
g = expr.denominator
f_const = self._isc(f)
g_const = self._isc(g)
df = self.rec(f)
dg = self.rec(g)
if f_const and g_const:
if (not df) and (not dg):
return 0
elif f_const:
elif (not df):
return -f*self.rec(g)/g**2
elif g_const:
elif (not dg):
return self.rec(f)/g
else:
return (self.rec(f)*g-self.rec(g)*f)/g**2
......@@ -90,16 +88,16 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper):
def map_power(self, expr):
f = expr.base
g = expr.exponent
f_const = self._isc(f)
g_const = self._isc(g)
df = self.rec(f)
dg = self.rec(g)
log = pymbolic.var("log")
if f_const and g_const:
if (not df) and (not dg):
return 0
elif f_const:
elif (not df):
return log(f) * f**g * self.rec(g)
elif g_const:
elif (not dg):
return g * f**(g-1) * self.rec(f)
else:
return log(f) * f**g * self.rec(g) + \
......@@ -124,14 +122,6 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper):
return \
Polynomial(expr.base, tuple(deriv_coeff), expr.unit) + \
Polynomial(expr.base, tuple(deriv_base), expr.unit)
def _isc(self,subexp):
return pymbolic.is_constant(subexp, [self.Variable],
include_lookups=True, include_subscripts=True
)
......
......@@ -74,6 +74,12 @@ class EvaluationMapper(RecursiveMapper):
class FloatEvaluationMapper(EvaluationMapper):
def handle_unsupported_expression(self, expr):
try:
return float(expr)
except:
raise TypeError, "cannot convert %s to float" % type(expr)
def map_constant(self, expr):
return float(expr)
......
......@@ -154,31 +154,38 @@ class Expression(object):
return "%s(%s)" % (self.__class__.__name__, initargs_str)
class AlgebraicLeaf(Expression):
pass
class Leaf(AlgebraicLeaf):
pass
class Variable(Leaf):
def __init__(self, name):
self._Name = name
self.name = name
def __getinitargs__(self):
return self._Name,
def _name(self):
return self._Name
name = property(_name)
def __lt__(self, other):
if isinstance(other, Variable):
return self._Name.__lt__(other._Name)
return self.name.__lt__(other.name)
else:
return NotImplemented
def __eq__(self, other):
return isinstance(other, Variable) and self._Name == other._Name
return (isinstance(other, Variable)
and self.name == other.name)
def __hash__(self):
return 0x111 ^ hash(self.name)
......@@ -188,24 +195,16 @@ class Variable(Leaf):
class Call(AlgebraicLeaf):
def __init__(self, function, parameters):
self._Function = function
self._Parameters = parameters
self.function = function
self.parameters = parameters
def __getinitargs__(self):
return self._Function, self._Parameters
def _function(self):
return self._Function
function = property(_function)
def _parameters(self):
return self._Parameters
parameters = property(_parameters)
return self.function, self.parameters
def __eq__(self, other):
return isinstance(other, Call) \
and (self._Function == other._Function) \
and (self._Parameters == other._Parameters)
and (self.function == other.function) \
and (self.parameters == other.parameters)
def __hash__(self):
return hash(self.function) ^ hash(self.parameters)
......@@ -213,26 +212,21 @@ class Call(AlgebraicLeaf):
def get_mapper_method(self, mapper):
return mapper.map_call
class Subscript(AlgebraicLeaf):
def __init__(self, aggregate, index):
self._Aggregate = aggregate
self._Index = index
self.aggregate = aggregate
self.index = index
def __getinitargs__(self):
return self._Aggregate, self._Index
def _aggregate(self):
return self._Aggregate
aggregate = property(_aggregate)
def _index(self):
return self._Index
index = property(_index)
def __eq__(self, other):
return isinstance(other, Subscript) \
and (self._Aggregate == other._Aggregate) \
and (self._Index == other._Index)
and (self.aggregate == other.aggregate) \
and (self.index == other.index)
def __hash__(self):
return 0x123 ^ hash(self.aggregate) ^ hash(self.index)
......@@ -242,26 +236,19 @@ class Subscript(AlgebraicLeaf):
class ElementLookup(AlgebraicLeaf):
def __init__(self, aggregate, name):
self._Aggregate = aggregate
self._Name = name
self.aggregate = aggregate
self.name = name
def __getinitargs__(self):
return self._Aggregate, self._Name
def _aggregate(self):
return self._Aggregate
aggregate = property(_aggregate)
def _name(self):
return self._Name
name = property(_name)
def __eq__(self, other):
return isinstance(other, ElementLookup) \
and (self._Aggregate == other._Aggregate) \
and (self._Name == other._Name)
and (self.aggregate == other.aggregate) \
and (self.name == other.name)
def __hash__(self):
return 0x183 ^ hash(self.aggregate) ^ hash(self.name)
......@@ -269,41 +256,41 @@ class ElementLookup(AlgebraicLeaf):
def get_mapper_method(self, mapper):
return mapper.map_lookup
class Sum(Expression):
def __init__(self, children):
assert isinstance(children, tuple)
self._Children = children
self.children = children
def __getinitargs__(self):
return self._Children
def _children(self):
return self._Children
children = property(_children)
return self.children
def __eq__(self, other):
return isinstance(other, Sum) and (self._Children == other._Children)
return (isinstance(other, Sum)
and (set(self.children) == set(other.children)))
def __add__(self, other):
if not is_valid_operand(other):
return NotImplemented
if isinstance(other, Sum):
return Sum(self._Children + other._Children)
return Sum(self.children + other.children)
if not other:
return self
return Sum(self._Children + (other,))
return Sum(self.children + (other,))
def __radd__(self, other):
if not is_constant(other):
return NotImplemented
if isinstance(other, Sum):
return Sum(other._Children + self._Children)
return Sum(other.children + self.children)
if not other:
return self
return Sum((other,) + self._Children)
return Sum((other,) + self.children)
def __sub__(self, other):
if not is_valid_operand(other):
......@@ -311,13 +298,13 @@ class Sum(Expression):
if not other:
return self
return Sum(self._Children + (-other,))
return Sum(self.children + (-other,))
def __nonzero__(self):
if len(self._Children) == 0:
if len(self.children) == 0:
return True
elif len(self._Children) == 1:
return bool(self._Children[0])
elif len(self.children) == 1:
return bool(self.children[0])
else:
# FIXME: Right semantics?
return True
......@@ -328,45 +315,45 @@ class Sum(Expression):
def get_mapper_method(self, mapper):
return mapper.map_sum
class Product(Expression):
def __init__(self, children):
assert isinstance(children, tuple)
self._Children = children
self.children = children
def __getinitargs__(self):
return self._Children
def _children(self):
return self._Children
children = property(_children)
return self.children
def __eq__(self, other):
return isinstance(other, Product) and (self._Children == other._Children)
return (isinstance(other, Product)
and (set(self.children) == set(other.children)))
def __mul__(self, other):
if not is_valid_operand(other):
return NotImplemented
if isinstance(other, Product):
return Product(self._Children + other._Children)
return Product(self.children + other.children)
if not other:
return 0
if not other-1:
return self
return Product(self._Children + (other,))
return Product(self.children + (other,))
def __rmul__(self, other):
if not is_constant(other):
return NotImplemented
if isinstance(other, Product):
return Product(other._Children + self._Children)
return Product(other.children + self.children)
if not other:
return 0
if not other-1:
return self
return Product((other,) + self._Children)
return Product((other,) + self.children)
def __nonzero__(self):
for i in self._Children:
for i in self.children:
if not i:
return False
return True
......@@ -377,30 +364,33 @@ class Product(Expression):
def get_mapper_method(self, mapper):
return mapper.map_product
class Quotient(Expression):
def __init__(self, numerator, denominator=1):
self._Numerator = numerator
self._Denominator = denominator
self.numerator = numerator
self.denominator = denominator
def __getinitargs__(self):
return self._Numerator, self._Denominator
return self.numerator, self.denominator
def _num(self):
return self._Numerator
numerator=property(_num)
@property
def num(self):
return self.numerator
def _den(self):
return self._Denominator
denominator=property(_den)
@property
def den(self):
return self.denominator
def __eq__(self, other):
from pymbolic.rational import Rational
return isinstance(other, (Rational, Quotient)) \
and (self._Numerator == other.numerator) \
and (self._Denominator == other.denominator)
and (self.numerator == other.numerator) \
and (self.denominator == other.denominator)
def __nonzero__(self):
return bool(self._Numerator)
return bool(self.numerator)
def __hash__(self):
return 0xabc ^ hash(self.numerator) ^ hash(self.denominator)
......@@ -408,26 +398,21 @@ class Quotient(Expression):
def get_mapper_method(self, mapper):
return mapper.map_quotient
class Power(Expression):
def __init__(self, base, exponent):
self._Base = base
self._Exponent = exponent
self.base = base
self.exponent = exponent
def __getinitargs__(self):
return self._Base, self._Exponent
def _base(self):
return self._Base
base = property(_base)
def _exponent(self):
return self._Exponent
exponent = property(_exponent)
return self.base, self.exponent
def __eq__(self, other):
return isinstance(other, Power) \
and (self._Base == other._Base) \
and (self._Exponent == other._Exponent)
and (self.base == other.base) \
and (self.exponent == other.exponent)
def __hash__(self):
return 0xdef ^ hash(self.base) ^ hash(self.exponent)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment