diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 6e1519eff9e7dfe558384ef9bb165a0e2ab0b006..44057dfb816ee9e374ef919b5061361d878aa80b 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -225,16 +225,20 @@ class StringifyMapper(pymbolic.mapper.Mapper): def map_left_shift(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( + # +1 to address + # https://gitlab.tiker.net/inducer/pymbolic/issues/6 self.format("%s << %s", - self.rec(expr.shiftee, PREC_SHIFT, *args, **kwargs), - self.rec(expr.shift, PREC_SHIFT, *args, **kwargs)), + self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), enclosing_prec, PREC_SHIFT) def map_right_shift(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( + # +1 to address + # https://gitlab.tiker.net/inducer/pymbolic/issues/6 self.format("%s >> %s", - self.rec(expr.shiftee, PREC_SHIFT, *args, **kwargs), - self.rec(expr.shift, PREC_SHIFT, *args, **kwargs)), + self.rec(expr.shiftee, PREC_SHIFT+1, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT+1, *args, **kwargs)), enclosing_prec, PREC_SHIFT) def map_bitwise_not(self, expr, enclosing_prec, *args, **kwargs): diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 44cd8cd4371d0d038293428387a52b796418d03a..f6b8733dc4d69ad87bb47985a605b116e4a6a699 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -22,9 +22,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import pymbolic import pymbolic.primitives as prim import pytest +from pymbolic import parse from pymbolic.mapper import IdentityMapper @@ -35,9 +35,6 @@ except NameError: from functools import reduce -pymbolic.disable_subscript_by_getitem() - - def test_integer_power(): from pymbolic.algorithm import integer_power @@ -487,6 +484,14 @@ def test_long_sympy_mapping(): SympyToPymbolicMapper()(sp.sympify(int(10))) +def test_stringifier_preserve_shift_order(): + for expr in [ + parse("(a << b) >> 2"), + parse("a << (b >> 2)") + ]: + assert parse(str(expr)) == expr + + if __name__ == "__main__": import sys if len(sys.argv) > 1: