diff --git a/pymbolic/mapper/evaluator.py b/pymbolic/mapper/evaluator.py index 6d29667173c935ac8e22e67b0fe8b6965fad817c..f190a18e433243a326f72eb9ec4c533a6ebb8858 100644 --- a/pymbolic/mapper/evaluator.py +++ b/pymbolic/mapper/evaluator.py @@ -43,6 +43,9 @@ class EvaluationMapper(RecursiveMapper): def map_quotient(self, expr): return self.rec(expr.numerator) / self.rec(expr.denominator) + def map_floor_div(self, expr): + return self.rec(expr.numerator) // self.rec(expr.denominator) + def map_power(self, expr): return self.rec(expr.base) ** self.rec(expr.exponent) @@ -85,6 +88,12 @@ class EvaluationMapper(RecursiveMapper): else: return self.rec(expr.else_) + def map_min(self, expr): + return min(self.rec(child) for child in expr.children) + + def map_max(self, expr): + return min(self.rec(child) for child in expr.children) + diff --git a/pymbolic/mapper/flop_counter.py b/pymbolic/mapper/flop_counter.py index 3cd124598d06eb6f442f1fcd4e92826c2783bbcd..d59238f2cad8507c9292f852780a932b31cde5dc 100644 --- a/pymbolic/mapper/flop_counter.py +++ b/pymbolic/mapper/flop_counter.py @@ -24,6 +24,8 @@ class FlopCounter(CombineMapper): def map_quotient(self, expr, *args): return 1 + self.rec(expr.numerator) + self.rec(expr.denominator) + map_floor_div = map_quotient + def map_power(self, expr, *args): return 1 + self.rec(expr.base) + self.rec(expr.exponent) diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 2e215b43bf766586bc009a7340b3ff5ad460f6b8..640b8051aa370613ed26949464f0811448d36e16 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -168,6 +168,12 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): self.rec(expr.then, PREC_NONE), self.rec(expr.else_, PREC_NONE)) + def map_min(self, expr, enclosing_prec): + what = type(expr).__name__.lower() + return self.format("%s(%s)", what, self.join_rec(", ", expr.children, PREC_NONE)) + + map_max = map_min + diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 79804b6fd0b2216175ee2e04a5d4a8b39a051590..bae27e304dc8fed8bf849c884201b78d203b8eb5 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -115,12 +115,12 @@ class Expression(object): return self return Remainder(self, other) - def __rmod(self, other): + def __rmod__(self, other): if not is_valid_operand(other): return NotImplemented return Remainder(other, self) - + def __pow__(self, other): if not is_valid_operand(other): return NotImplemented @@ -658,6 +658,29 @@ class IfPositive(Expression): +class _MinMaxBase(Expression): + def __init__(self, children): + self.children = children + + def __getinitargs__(self): + return self.children + + def is_equal(self, other): + return (isinstance(other, type(self)) + and self.children == other.children) + + def get_hash(self): + return hash((type(self), self.children)) + +class Min(_MinMaxBase): + mapper_method = intern("map_min") + +class Max(_MinMaxBase): + mapper_method = intern("map_max") + + + + # intelligent makers --------------------------------------------------------- def make_variable(var_or_string): if not isinstance(var_or_string, Expression):