diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index f86eaf24c963db756588874be6d4c93caba2329f..24be3ecf95afdc248a5a5fd54bca0020a870b8bd 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -778,7 +778,7 @@ class Sum(_MultiChildExpression): return self return Sum(self.children + (-other,)) - def __nonzero__(self): + def __bool__(self): if len(self.children) == 0: return True elif len(self.children) == 1: @@ -787,6 +787,8 @@ class Sum(_MultiChildExpression): # FIXME: Right semantics? return True + __nonzero__ = __bool__ + mapper_method = intern("map_sum") @@ -819,12 +821,14 @@ class Product(_MultiChildExpression): return self return Product((other,) + self.children) - def __nonzero__(self): + def __bool__(self): for i in self.children: if is_zero(i): return False return True + __nonzero__ = __bool__ + mapper_method = intern("map_product") @@ -846,9 +850,11 @@ class QuotientBase(Expression): def den(self): return self.denominator - def __nonzero__(self): + def __bool__(self): return bool(self.numerator) + __nonzero__ = __bool__ + class Quotient(QuotientBase): """ @@ -1131,12 +1137,14 @@ class Vector(Expression): "(depending on the required semantics)", DeprecationWarning) - def __nonzero__(self): + def __bool__(self): for i in self.children: if is_nonzero(i): return False return True + __nonzero__ = __bool__ + def __len__(self): return len(self.children) @@ -1305,9 +1313,11 @@ class Slice(Expression): def __getinitargs__(self): return (self.children,) - def __nonzero__(self): + def __bool__(self): return True + __nonzero__ = __bool__ + @property def start(self): if len(self.children) > 1: diff --git a/pymbolic/rational.py b/pymbolic/rational.py index 7296881c9877b922778e600d38293012ea8b85de..7ab677f60fd64b3f906b84c537f5184b367565ae 100644 --- a/pymbolic/rational.py +++ b/pymbolic/rational.py @@ -27,8 +27,6 @@ import pymbolic.primitives as primitives import pymbolic.traits as traits - - class Rational(primitives.Expression): def __init__(self, numerator, denominator=1): d_unit = traits.traits(denominator).get_unit(denominator) @@ -45,9 +43,11 @@ class Rational(primitives.Expression): return self.Denominator denominator = property(_den) - def __nonzero__(self): + def __bool__(self): return bool(self.Numerator) + __nonzero__ = __bool__ + def __neg__(self): return Rational(-self.Numerator, self.Denominator) @@ -132,8 +132,6 @@ class Rational(primitives.Expression): mapper_method = intern("map_rational") - - if __name__ == "__main__": one = Rational(1) print(3 + 1/(1 - 3/(one + 17)))