From 9d71e9d480b813c73d537de8d2c5b78e40bc3616 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 11 Jun 2009 00:57:56 -0400
Subject: [PATCH] Introduce remainder into Pymbolic's expression language.

---
 pymbolic/mapper/stringifier.py |  7 ++++++
 pymbolic/primitives.py         | 43 +++++++++++++++++++++++++++++-----
 2 files changed, 44 insertions(+), 6 deletions(-)

diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py
index 1926414..2b79ace 100644
--- a/pymbolic/mapper/stringifier.py
+++ b/pymbolic/mapper/stringifier.py
@@ -106,6 +106,13 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper):
                     self.rec(expr.exponent, PREC_POWER)),
                 enclosing_prec, PREC_POWER)
 
+    def map_remainder(self, expr, enclosing_prec):
+        return self.parenthesize_if_needed(
+                self.format("%s %% %s", 
+                    self.rec(expr.numerator, PREC_PRODUCT), 
+                    self.rec(expr.denominator, PREC_POWER)), # analogous to ^{-1}
+                enclosing_prec, PREC_PRODUCT)
+
     def map_polynomial(self, expr, enclosing_prec):
         from pymbolic.primitives import flattened_sum
         return self.rec(flattened_sum(
diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py
index f16303f..7a654f1 100644
--- a/pymbolic/primitives.py
+++ b/pymbolic/primitives.py
@@ -89,6 +89,20 @@ class Expression(object):
         return quotient(other, self)
     __rtruediv__ = __rdiv__
 
+    def __mod__(self, other):
+        if not is_valid_operand(other):
+            return NotImplemented
+
+        if is_zero(other-1):
+            return self
+        return Remainder(self, other)
+
+    def __rmod(self, other):
+        if not is_valid_operand(other):
+            return NotImplemented
+
+        return Remainder(other, self)
+    
     def __pow__(self, other):
         if not is_valid_operand(other):
             return NotImplemented
@@ -442,7 +456,7 @@ class Product(Expression):
 
 
 
-class Quotient(Expression):
+class QuotientBase(Expression):
     def __init__(self, numerator, denominator=1):
         self.numerator = numerator
         self.denominator = denominator
@@ -458,20 +472,37 @@ class Quotient(Expression):
     def den(self):
         return self.denominator
 
+    def __nonzero__(self):
+        return bool(self.numerator)
+
+    def get_hash(self):
+        return hash((self.__class__, self.numerator, self.denominator))
+
+
+
+
+class Quotient(QuotientBase):
     def is_equal(self, other):
         from pymbolic.rational import Rational
         return isinstance(other, (Rational, Quotient)) \
                and (self.numerator == other.numerator) \
                and (self.denominator == other.denominator)
 
-    def __nonzero__(self):
-        return bool(self.numerator)
+    def get_mapper_method(self, mapper):
+        return mapper.map_quotient
 
-    def get_hash(self):
-        return hash((self.__class__, self.numerator, self.denominator))
+
+
+
+class Remainder(QuotientBase):
+    def is_equal(self, other):
+        from pymbolic.rational import Rational
+        return self.__class__ == other.__class__ \
+               and (self.numerator == other.numerator) \
+               and (self.denominator == other.denominator)
 
     def get_mapper_method(self, mapper):
-        return mapper.map_quotient
+        return mapper.map_remainder
 
 
 
-- 
GitLab