diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 192641452c904835285674edc5e915d4045c6fe2..2b79ace0d89e850265afaa440c3309dfbc204c15 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -106,6 +106,13 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): self.rec(expr.exponent, PREC_POWER)), enclosing_prec, PREC_POWER) + def map_remainder(self, expr, enclosing_prec): + return self.parenthesize_if_needed( + self.format("%s %% %s", + self.rec(expr.numerator, PREC_PRODUCT), + self.rec(expr.denominator, PREC_POWER)), # analogous to ^{-1} + enclosing_prec, PREC_PRODUCT) + def map_polynomial(self, expr, enclosing_prec): from pymbolic.primitives import flattened_sum return self.rec(flattened_sum( diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index f16303f433b64b0b1e9e984ae434fee0ef32f1bc..7a654f1d482aeaa2edb9cc07df14c2837897de48 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -89,6 +89,20 @@ class Expression(object): return quotient(other, self) __rtruediv__ = __rdiv__ + def __mod__(self, other): + if not is_valid_operand(other): + return NotImplemented + + if is_zero(other-1): + return self + return Remainder(self, other) + + def __rmod(self, other): + if not is_valid_operand(other): + return NotImplemented + + return Remainder(other, self) + def __pow__(self, other): if not is_valid_operand(other): return NotImplemented @@ -442,7 +456,7 @@ class Product(Expression): -class Quotient(Expression): +class QuotientBase(Expression): def __init__(self, numerator, denominator=1): self.numerator = numerator self.denominator = denominator @@ -458,20 +472,37 @@ class Quotient(Expression): def den(self): return self.denominator + def __nonzero__(self): + return bool(self.numerator) + + def get_hash(self): + return hash((self.__class__, self.numerator, self.denominator)) + + + + +class Quotient(QuotientBase): def is_equal(self, other): from pymbolic.rational import Rational return isinstance(other, (Rational, Quotient)) \ and (self.numerator == other.numerator) \ and (self.denominator == other.denominator) - def __nonzero__(self): - return bool(self.numerator) + def get_mapper_method(self, mapper): + return mapper.map_quotient - def get_hash(self): - return hash((self.__class__, self.numerator, self.denominator)) + + + +class Remainder(QuotientBase): + def is_equal(self, other): + from pymbolic.rational import Rational + return self.__class__ == other.__class__ \ + and (self.numerator == other.numerator) \ + and (self.denominator == other.denominator) def get_mapper_method(self, mapper): - return mapper.map_quotient + return mapper.map_remainder