diff --git a/pytato/array.py b/pytato/array.py index 6a83a662ddeb774c0bd31f23ed84940fb9816712..6883e9837e7f50acbd8b6261bba982ce91fd3624 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1861,7 +1861,8 @@ def _get_reduction_indices_bounds(shape: ShapeType, indices.append(prim.Variable(f"_{n_out_dims}")) n_out_dims += 1 - return indices, redn_bounds + from pyrsistent import pmap + return indices, pmap(redn_bounds) def _reduction_lambda(op: str, a: Array, diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 7df2bbb7a0c44694b4f7cbca0df867e9933eb533..482709b66763310ad074970f59583c0bf419f1b4 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -25,7 +25,7 @@ THE SOFTWARE. """ from numbers import Number -from typing import Any, Union, Mapping, FrozenSet, Set, Tuple, Dict +from typing import Any, Union, Mapping, FrozenSet, Set, Tuple from pymbolic.mapper import (WalkMapper as WalkMapperBase, IdentityMapper as IdentityMapperBase) @@ -37,10 +37,11 @@ from pymbolic.mapper.evaluator import (EvaluationMapper as EvaluationMapperBase) from pymbolic.mapper.distributor import (DistributeMapper as DistributeMapperBase) +from pymbolic.mapper.stringifier import (StringifyMapper as + StringifyMapperBase) from pymbolic.mapper.collector import TermCollector as TermCollectorBase import pymbolic.primitives as prim import numpy as np -from dataclasses import dataclass, field __doc__ = """ .. currentmodule:: pytato.scalar_expr @@ -129,6 +130,18 @@ class DistributeMapper(DistributeMapperBase): class TermCollector(TermCollectorBase): pass + +class StringifyMapper(StringifyMapperBase): + def map_reduce(self, expr, enclosing_prec, *args) -> str: + from pymbolic.mapper.stringifier import ( + PREC_COMPARISON as PC, + PREC_NONE as PN) + bounds_expr = " and ".join( + f"{self.rec(lb, PC)}<={name}<{self.rec(ub, PC)}" + for name, (lb, ub) in expr.bounds.items()) + bounds_expr = "{" + bounds_expr + "}" + return (f"{expr.op}({bounds_expr}, {self.rec(expr.inner_expr, PN)})") + # }}} @@ -173,24 +186,42 @@ def distribute(expr: Any, parameters: Set[Any] = set(), # }}} -@dataclass -class Reduce(prim.Expression): +# {{{ custom scalar expression nodes + +class ExpressionBase(prim.Expression): + def make_stringifier(self, originating_stringifier=None): + return StringifyMapper() + + +class Reduce(ExpressionBase): + """ + .. attribute:: inner_expr + + A :class:`ScalarExpression` to be reduced over. + + .. attribute:: op + + One of ``"sum"``, ``"product"``, ``"max"``, ``"min"``. + + .. attribute:: bounds + + A mapping from reduction inames to tuples ``(lower_bound, upper_bound)`` + identifying half-open bounds intervals. Must be hashable. + """ inner_expr: ScalarExpression op: str - bounds: Dict[str, Tuple[ScalarExpression, ScalarExpression]] - mapper_method: str = field(init=False, default="map_reduce") - - def __hash__(self) -> int: - return hash((self.inner_expr, - self.op, - tuple(self.bounds.keys()), - tuple(self.bounds.values()))) - - def __str__(self) -> str: - bounds_expr = " and ".join(f"{lb}<={key}<{ub}" - for key, (lb, ub) in self.bounds.items()) - bounds_expr = "{" + bounds_expr + "}" - return (f"{self.op}({bounds_expr}, {self.inner_expr})") + bounds: Mapping[str, Tuple[ScalarExpression, ScalarExpression]] + def __init__(self, inner_expr, op, bounds): + self.inner_expr = inner_expr + self.op = op + self.bounds = bounds + + def __getinitargs__(self): + return (self.inner_expr, self.op, self.bounds) + + mapper_method = "map_reduce" + +# }}} # vim: foldmethod=marker diff --git a/setup.py b/setup.py index 3c4e24c56349e611c93db1db0303501dfb6cd247..93d5b88c9f1bdccb513e3336fbfe23d1b8aaef29 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,11 @@ setup( "Topic :: Software Development :: Libraries", ], python_requires="~=3.8", - install_requires=["loopy>=2020.2", "pytools>=2021.1"], + install_requires=[ + "loopy>=2020.2", + "pytools>=2021.1", + "pyrsistent" + ], author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei", url="https://github.com/inducer/pytato", author_email="inform@tiker.net",