From 9d71e9d480b813c73d537de8d2c5b78e40bc3616 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 11 Jun 2009 00:57:56 -0400 Subject: [PATCH] Introduce remainder into Pymbolic's expression language. --- pymbolic/mapper/stringifier.py | 7 ++++++ pymbolic/primitives.py | 43 +++++++++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 1926414..2b79ace 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -106,6 +106,13 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): self.rec(expr.exponent, PREC_POWER)), enclosing_prec, PREC_POWER) + def map_remainder(self, expr, enclosing_prec): + return self.parenthesize_if_needed( + self.format("%s %% %s", + self.rec(expr.numerator, PREC_PRODUCT), + self.rec(expr.denominator, PREC_POWER)), # analogous to ^{-1} + enclosing_prec, PREC_PRODUCT) + def map_polynomial(self, expr, enclosing_prec): from pymbolic.primitives import flattened_sum return self.rec(flattened_sum( diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index f16303f..7a654f1 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -89,6 +89,20 @@ class Expression(object): return quotient(other, self) __rtruediv__ = __rdiv__ + def __mod__(self, other): + if not is_valid_operand(other): + return NotImplemented + + if is_zero(other-1): + return self + return Remainder(self, other) + + def __rmod(self, other): + if not is_valid_operand(other): + return NotImplemented + + return Remainder(other, self) + def __pow__(self, other): if not is_valid_operand(other): return NotImplemented @@ -442,7 +456,7 @@ class Product(Expression): -class Quotient(Expression): +class QuotientBase(Expression): def __init__(self, numerator, denominator=1): self.numerator = numerator self.denominator = denominator @@ -458,20 +472,37 @@ class Quotient(Expression): def den(self): return self.denominator + def __nonzero__(self): + return bool(self.numerator) + + def get_hash(self): + return hash((self.__class__, self.numerator, self.denominator)) + + + + +class Quotient(QuotientBase): def is_equal(self, other): from pymbolic.rational import Rational return isinstance(other, (Rational, Quotient)) \ and (self.numerator == other.numerator) \ and (self.denominator == other.denominator) - def __nonzero__(self): - return bool(self.numerator) + def get_mapper_method(self, mapper): + return mapper.map_quotient - def get_hash(self): - return hash((self.__class__, self.numerator, self.denominator)) + + + +class Remainder(QuotientBase): + def is_equal(self, other): + from pymbolic.rational import Rational + return self.__class__ == other.__class__ \ + and (self.numerator == other.numerator) \ + and (self.denominator == other.denominator) def get_mapper_method(self, mapper): - return mapper.map_quotient + return mapper.map_remainder -- GitLab