From 60aae072e4491e2a64b6427d650cb99b7ee9c5d5 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 24 Sep 2007 22:54:57 -0400 Subject: [PATCH] Remove unnecessary property functions. Incremental improvements. Remove is_constant from differentiator. Trivial fixes. --- src/mapper/dependency.py | 4 +- src/mapper/differentiator.py | 52 +++++----- src/mapper/evaluator.py | 6 ++ src/primitives.py | 177 ++++++++++++++++------------------- 4 files changed, 110 insertions(+), 129 deletions(-) diff --git a/src/mapper/dependency.py b/src/mapper/dependency.py index d42baac..f4b450e 100644 --- a/src/mapper/dependency.py +++ b/src/mapper/dependency.py @@ -24,13 +24,13 @@ class DependencyMapper(CombineMapper): def map_variable(self, expr): return set([expr]) - def map_call(self, expr, *args, **kwargs): + def map_call(self, expr): if self.IncludeCalls: return set([expr]) else: return CombineMapper.map_call(self, expr) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, expr): if self.IncludeLookups: return set([expr]) else: diff --git a/src/mapper/differentiator.py b/src/mapper/differentiator.py index 74a3573..e19edc3 100644 --- a/src/mapper/differentiator.py +++ b/src/mapper/differentiator.py @@ -34,25 +34,25 @@ def map_math_functions_by_name(i, func, pars): class DifferentiationMapper(pymbolic.mapper.RecursiveMapper): - def __init__(self, variable, func_map): - self.Variable = variable - self.FunctionMap = func_map + def __init__(self, variable, func_map=map_math_functions_by_name): + self.variable = variable + self.function_map = func_map def map_constant(self, expr): return 0 def map_variable(self, expr): - if expr == self.Variable: + if expr == self.variable: return 1 else: return 0 def map_call(self, expr): - return pymbolic.sum( - self.FunctionMap(i, expr.function, expr.parameters) + return pymbolic.flattened_sum( + self.function_map(i, expr.function, expr.parameters) * self.rec(par) for i, par in enumerate(expr.parameters) - if not self._isc(par)) + ) map_subscript = map_variable @@ -60,29 +60,27 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper): return -self.rec(expr.child) def map_sum(self, expr): - return pymbolic.sum(self.rec(child) for child in expr.children - if not self._isc(child)) + return pymbolic.flattened_sum(self.rec(child) for child in expr.children) def map_product(self, expr): - return pymbolic.sum( - pymbolic.product( + return pymbolic.flattened_sum( + pymbolic.flattened_product( expr.children[0:i] + (self.rec(child),) + expr.children[i+1:]) - for i, child in enumerate(expr.children) - if not self._isc(child)) + for i, child in enumerate(expr.children)) def map_quotient(self, expr): f = expr.numerator g = expr.denominator - f_const = self._isc(f) - g_const = self._isc(g) + df = self.rec(f) + dg = self.rec(g) - if f_const and g_const: + if (not df) and (not dg): return 0 - elif f_const: + elif (not df): return -f*self.rec(g)/g**2 - elif g_const: + elif (not dg): return self.rec(f)/g else: return (self.rec(f)*g-self.rec(g)*f)/g**2 @@ -90,16 +88,16 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper): def map_power(self, expr): f = expr.base g = expr.exponent - f_const = self._isc(f) - g_const = self._isc(g) + df = self.rec(f) + dg = self.rec(g) log = pymbolic.var("log") - if f_const and g_const: + if (not df) and (not dg): return 0 - elif f_const: + elif (not df): return log(f) * f**g * self.rec(g) - elif g_const: + elif (not dg): return g * f**(g-1) * self.rec(f) else: return log(f) * f**g * self.rec(g) + \ @@ -124,14 +122,6 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper): return \ Polynomial(expr.base, tuple(deriv_coeff), expr.unit) + \ Polynomial(expr.base, tuple(deriv_base), expr.unit) - - - - - def _isc(self,subexp): - return pymbolic.is_constant(subexp, [self.Variable], - include_lookups=True, include_subscripts=True - ) diff --git a/src/mapper/evaluator.py b/src/mapper/evaluator.py index f205a0c..6886333 100644 --- a/src/mapper/evaluator.py +++ b/src/mapper/evaluator.py @@ -74,6 +74,12 @@ class EvaluationMapper(RecursiveMapper): class FloatEvaluationMapper(EvaluationMapper): + def handle_unsupported_expression(self, expr): + try: + return float(expr) + except: + raise TypeError, "cannot convert %s to float" % type(expr) + def map_constant(self, expr): return float(expr) diff --git a/src/primitives.py b/src/primitives.py index cd88034..6797572 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -154,31 +154,38 @@ class Expression(object): return "%s(%s)" % (self.__class__.__name__, initargs_str) + + + + class AlgebraicLeaf(Expression): pass + + + class Leaf(AlgebraicLeaf): pass + + + class Variable(Leaf): def __init__(self, name): - self._Name = name + self.name = name def __getinitargs__(self): return self._Name, - def _name(self): - return self._Name - name = property(_name) - def __lt__(self, other): if isinstance(other, Variable): - return self._Name.__lt__(other._Name) + return self.name.__lt__(other.name) else: return NotImplemented def __eq__(self, other): - return isinstance(other, Variable) and self._Name == other._Name + return (isinstance(other, Variable) + and self.name == other.name) def __hash__(self): return 0x111 ^ hash(self.name) @@ -188,24 +195,16 @@ class Variable(Leaf): class Call(AlgebraicLeaf): def __init__(self, function, parameters): - self._Function = function - self._Parameters = parameters + self.function = function + self.parameters = parameters def __getinitargs__(self): - return self._Function, self._Parameters - - def _function(self): - return self._Function - function = property(_function) - - def _parameters(self): - return self._Parameters - parameters = property(_parameters) + return self.function, self.parameters def __eq__(self, other): return isinstance(other, Call) \ - and (self._Function == other._Function) \ - and (self._Parameters == other._Parameters) + and (self.function == other.function) \ + and (self.parameters == other.parameters) def __hash__(self): return hash(self.function) ^ hash(self.parameters) @@ -213,26 +212,21 @@ class Call(AlgebraicLeaf): def get_mapper_method(self, mapper): return mapper.map_call + + + class Subscript(AlgebraicLeaf): def __init__(self, aggregate, index): - self._Aggregate = aggregate - self._Index = index + self.aggregate = aggregate + self.index = index def __getinitargs__(self): return self._Aggregate, self._Index - def _aggregate(self): - return self._Aggregate - aggregate = property(_aggregate) - - def _index(self): - return self._Index - index = property(_index) - def __eq__(self, other): return isinstance(other, Subscript) \ - and (self._Aggregate == other._Aggregate) \ - and (self._Index == other._Index) + and (self.aggregate == other.aggregate) \ + and (self.index == other.index) def __hash__(self): return 0x123 ^ hash(self.aggregate) ^ hash(self.index) @@ -242,26 +236,19 @@ class Subscript(AlgebraicLeaf): + class ElementLookup(AlgebraicLeaf): def __init__(self, aggregate, name): - self._Aggregate = aggregate - self._Name = name + self.aggregate = aggregate + self.name = name def __getinitargs__(self): return self._Aggregate, self._Name - def _aggregate(self): - return self._Aggregate - aggregate = property(_aggregate) - - def _name(self): - return self._Name - name = property(_name) - def __eq__(self, other): return isinstance(other, ElementLookup) \ - and (self._Aggregate == other._Aggregate) \ - and (self._Name == other._Name) + and (self.aggregate == other.aggregate) \ + and (self.name == other.name) def __hash__(self): return 0x183 ^ hash(self.aggregate) ^ hash(self.name) @@ -269,41 +256,41 @@ class ElementLookup(AlgebraicLeaf): def get_mapper_method(self, mapper): return mapper.map_lookup + + + class Sum(Expression): def __init__(self, children): assert isinstance(children, tuple) - self._Children = children + self.children = children def __getinitargs__(self): - return self._Children - - def _children(self): - return self._Children - children = property(_children) + return self.children def __eq__(self, other): - return isinstance(other, Sum) and (self._Children == other._Children) + return (isinstance(other, Sum) + and (set(self.children) == set(other.children))) def __add__(self, other): if not is_valid_operand(other): return NotImplemented if isinstance(other, Sum): - return Sum(self._Children + other._Children) + return Sum(self.children + other.children) if not other: return self - return Sum(self._Children + (other,)) + return Sum(self.children + (other,)) def __radd__(self, other): if not is_constant(other): return NotImplemented if isinstance(other, Sum): - return Sum(other._Children + self._Children) + return Sum(other.children + self.children) if not other: return self - return Sum((other,) + self._Children) + return Sum((other,) + self.children) def __sub__(self, other): if not is_valid_operand(other): @@ -311,13 +298,13 @@ class Sum(Expression): if not other: return self - return Sum(self._Children + (-other,)) + return Sum(self.children + (-other,)) def __nonzero__(self): - if len(self._Children) == 0: + if len(self.children) == 0: return True - elif len(self._Children) == 1: - return bool(self._Children[0]) + elif len(self.children) == 1: + return bool(self.children[0]) else: # FIXME: Right semantics? return True @@ -328,45 +315,45 @@ class Sum(Expression): def get_mapper_method(self, mapper): return mapper.map_sum + + + class Product(Expression): def __init__(self, children): assert isinstance(children, tuple) - self._Children = children + self.children = children def __getinitargs__(self): - return self._Children - - def _children(self): - return self._Children - children = property(_children) + return self.children def __eq__(self, other): - return isinstance(other, Product) and (self._Children == other._Children) + return (isinstance(other, Product) + and (set(self.children) == set(other.children))) def __mul__(self, other): if not is_valid_operand(other): return NotImplemented if isinstance(other, Product): - return Product(self._Children + other._Children) + return Product(self.children + other.children) if not other: return 0 if not other-1: return self - return Product(self._Children + (other,)) + return Product(self.children + (other,)) def __rmul__(self, other): if not is_constant(other): return NotImplemented if isinstance(other, Product): - return Product(other._Children + self._Children) + return Product(other.children + self.children) if not other: return 0 if not other-1: return self - return Product((other,) + self._Children) + return Product((other,) + self.children) def __nonzero__(self): - for i in self._Children: + for i in self.children: if not i: return False return True @@ -377,30 +364,33 @@ class Product(Expression): def get_mapper_method(self, mapper): return mapper.map_product + + + class Quotient(Expression): def __init__(self, numerator, denominator=1): - self._Numerator = numerator - self._Denominator = denominator + self.numerator = numerator + self.denominator = denominator def __getinitargs__(self): - return self._Numerator, self._Denominator + return self.numerator, self.denominator - def _num(self): - return self._Numerator - numerator=property(_num) + @property + def num(self): + return self.numerator - def _den(self): - return self._Denominator - denominator=property(_den) + @property + def den(self): + return self.denominator def __eq__(self, other): from pymbolic.rational import Rational return isinstance(other, (Rational, Quotient)) \ - and (self._Numerator == other.numerator) \ - and (self._Denominator == other.denominator) + and (self.numerator == other.numerator) \ + and (self.denominator == other.denominator) def __nonzero__(self): - return bool(self._Numerator) + return bool(self.numerator) def __hash__(self): return 0xabc ^ hash(self.numerator) ^ hash(self.denominator) @@ -408,26 +398,21 @@ class Quotient(Expression): def get_mapper_method(self, mapper): return mapper.map_quotient + + + class Power(Expression): def __init__(self, base, exponent): - self._Base = base - self._Exponent = exponent + self.base = base + self.exponent = exponent def __getinitargs__(self): - return self._Base, self._Exponent - - def _base(self): - return self._Base - base = property(_base) - - def _exponent(self): - return self._Exponent - exponent = property(_exponent) + return self.base, self.exponent def __eq__(self, other): return isinstance(other, Power) \ - and (self._Base == other._Base) \ - and (self._Exponent == other._Exponent) + and (self.base == other.base) \ + and (self.exponent == other.exponent) def __hash__(self): return 0xdef ^ hash(self.base) ^ hash(self.exponent) -- GitLab