diff --git a/pytato/array.py b/pytato/array.py
index 9fdf1de028ca90342145c646176482c58f5c39f5..64f47890fbc3f6f81f68d261303477f66f0eb720 100644
--- a/pytato/array.py
+++ b/pytato/array.py
@@ -295,7 +295,7 @@ def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype[A
     return dtype
 
 
-@attrs.define(frozen=True)
+@attrs.frozen
 class NormalizedSlice:
     """
     A normalized version of :class:`slice`. "Normalized" is explained in
@@ -320,7 +320,7 @@ class NormalizedSlice:
     step: IntegralT
 
 
-@attrs.define(frozen=True)
+@attrs.frozen
 class Axis(Taggable):
     """
     A type for recording the information about an :class:`~pytato.Array`'s
@@ -333,7 +333,7 @@ class Axis(Taggable):
         return replace(self, tags=tags)
 
 
-@attrs.define(frozen=True)
+@attrs.frozen
 class ReductionDescriptor(Taggable):
     """
     Records information about a reduction dimension in an
@@ -346,7 +346,7 @@ class ReductionDescriptor(Taggable):
         return replace(self, tags=tags)
 
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class Array(Taggable):
     r"""
     A base class (abstract interface + supplemental functionality) for lazily
@@ -441,41 +441,32 @@ class Array(Taggable):
 
     _mapper_method: ClassVar[str]
 
-    # A tuple of field names. Fields must be equality comparable and
-    # hashable. Dicts of hashable keys and values are also permitted.
-    _fields: ClassVar[Tuple[str, ...]] = ("axes", "tags",)
-
     # disallow numpy arithmetic from taking precedence
     __array_priority__: ClassVar[int] = 1
 
     def _is_eq_valid(self) -> bool:
-        return (self.__class__.__eq__ is Array.__eq__
-                and self.__class__.__hash__ is Array.__hash__)
-
-    def __post_init__(self) -> None:
-        # ensure that a developer does not uses dataclass' "__eq__"
-        # or "__hash__" implementation as they have exponential complexity.
-        assert self._is_eq_valid()
+        return self.__class__.__eq__ is Array.__eq__
 
-    def __attrs_post_init__(self) -> None:
-        return self.__post_init__()
+    if __debug__:
+        def __attrs_post_init__(self) -> None:
+            # ensure that a developer does not use dataclass' "__eq__"
+            # or "__hash__" implementation as they have exponential complexity.
+            assert self._is_eq_valid()
 
     def copy(self: ArrayT, **kwargs: Any) -> ArrayT:
-        for field in self._fields:
-            if field not in kwargs:
-                kwargs[field] = getattr(self, field)
-        return type(self)(**kwargs)
+        return attrs.evolve(self, **kwargs)
 
     def _with_new_tags(self: ArrayT, tags: FrozenSet[Tag]) -> ArrayT:
-        return self.copy(tags=tags)
+        return attrs.evolve(self, tags=tags)
 
-    @property
-    def shape(self) -> ShapeType:
-        raise NotImplementedError()
+    if TYPE_CHECKING:
+        @property
+        def shape(self) -> ShapeType:
+            raise NotImplementedError
 
-    @property
-    def dtype(self) -> _dtype_any:
-        raise NotImplementedError()
+        @property
+        def dtype(self) -> np.dtype[Any]:
+            raise NotImplementedError
 
     @property
     def size(self) -> ShapeComponent:
@@ -515,16 +506,6 @@ class Array(Taggable):
                 tags=_get_default_tags(),
                 axes=_get_default_axes(self.ndim))
 
-    @memoize_method
-    def __hash__(self) -> int:
-        attrs = []
-        for field in self._fields:
-            attr = getattr(self, field)
-            if isinstance(attr, dict):
-                attr = frozenset(attr.items())
-            attrs.append(attr)
-        return hash(tuple(attrs))
-
     def __eq__(self, other: Any) -> bool:
         if self is other:
             return True
@@ -684,27 +665,20 @@ class Array(Taggable):
 
 # {{{ mixins
 
+@attrs.frozen(eq=False, slots=False, repr=False)
 class _SuppliedShapeAndDtypeMixin:
     """A mixin class for when an array must store its own *shape* and *dtype*,
    rather than when it can derive them easily from inputs.
     """
-    _shape: ShapeType
-    _dtype: np.dtype[Any]
-
-    @property
-    def shape(self) -> ShapeType:
-        return self._shape
-
-    @property
-    def dtype(self) -> np.dtype[Any]:
-        return self._dtype
+    shape: ShapeType
+    dtype: np.dtype[Any]
 
 # }}}
 
 
 # {{{ dict of named arrays
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class NamedArray(Array):
     """An entry in a :class:`AbstractResultWithNamedArrays`. Holds a reference
     back to the containing instance as well as the name by which *self* is
@@ -715,7 +689,6 @@ class NamedArray(Array):
     _container: AbstractResultWithNamedArrays
     name: str
 
-    _fields: ClassVar[Tuple[str, ...]] = ("_container", "name", "axes", "tags",)
     _mapper_method: ClassVar[str] = "map_named_array"
 
     # type-ignore reason: `copy` signature incompatible with super-class
@@ -750,7 +723,7 @@ class NamedArray(Array):
         return self.expr.dtype
 
 
-@attrs.define(frozen=True, eq=False)
+@attrs.frozen(eq=False, hash=True, cache_hash=True)
 class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC):
     r"""An abstract array computation that results in multiple :class:`Array`\ s,
     each named. The way in which the values of these arrays are computed
@@ -772,14 +745,11 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC):
     def _is_eq_valid(self) -> bool:
         return self.__class__.__eq__ is AbstractResultWithNamedArrays.__eq__
 
-    def __post_init__(self) -> None:
+    def __attrs_post_init__(self) -> None:
         # ensure that a developer does not uses dataclass' "__eq__"
         # or "__hash__" implementation as they have exponential complexity.
         assert self._is_eq_valid()
 
-    def __attrs_post_init__(self) -> None:
-        return self.__post_init__()
-
     @abstractmethod
     def __contains__(self, name: object) -> bool:
         pass
@@ -800,7 +770,7 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC):
         return EqualityComparer()(self, other)
 
 
-@attrs.define(frozen=True, eq=False, init=False)
+@attrs.frozen(eq=False, init=False)
 class DictOfNamedArrays(AbstractResultWithNamedArrays):
     """A container of named results, each of which can be computed as an
     array expression provided to the constructor.
@@ -853,7 +823,7 @@ class DictOfNamedArrays(AbstractResultWithNamedArrays):
 
 # {{{ index lambda
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class IndexLambda(_SuppliedShapeAndDtypeMixin, Array):
     r"""Represents an array that can be computed by evaluating
     :attr:`expr` for every value of the input indices. The
@@ -889,16 +859,17 @@ class IndexLambda(_SuppliedShapeAndDtypeMixin, Array):
     .. automethod:: with_tagged_reduction
     """
     expr: prim.Expression
-    _shape: ShapeType
-    _dtype: np.dtype[Any]
-    bindings: Mapping[str, Array]
+    bindings: Mapping[str, Array] = attrs.field()
     var_to_reduction_descr: Mapping[str, ReductionDescriptor]
 
-    _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("expr", "shape", "dtype",
-                                                          "bindings",
-                                                          "var_to_reduction_descr")
     _mapper_method: ClassVar[str] = "map_index_lambda"
 
+    if __debug__:
+        @bindings.validator  # type: ignore[attr-defined, misc]
+        def _check_bindings(self, attribute: Any, value: Any) -> None:
+            if isinstance(value, dict):
+                raise TypeError("bindings may not be a dict")
+
     def with_tagged_reduction(self,
                               reduction_variable: str,
                               tag: Tag) -> IndexLambda:
@@ -949,7 +920,7 @@ class EinsumAxisDescriptor:
     pass
 
 
-@attrs.define(frozen=True)
+@attrs.frozen
 class EinsumElementwiseAxis(EinsumAxisDescriptor):
     """
     Describes an elementwise access pattern of an array's axis. In terms of the
@@ -959,7 +930,7 @@ class EinsumElementwiseAxis(EinsumAxisDescriptor):
     dim: int
 
 
-@attrs.define(frozen=True)
+@attrs.frozen
 class EinsumReductionAxis(EinsumAxisDescriptor):
     """
     Describes a reduction access pattern of an array's axis. In terms of the
@@ -969,7 +940,7 @@ class EinsumReductionAxis(EinsumAxisDescriptor):
     dim: int
 
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class Einsum(Array):
     """
     An array expression using the `Einstein summation convention
@@ -1008,10 +979,6 @@ class Einsum(Array):
     redn_axis_to_redn_descr: Mapping[EinsumReductionAxis, ReductionDescriptor]
     index_to_access_descr: Mapping[str, EinsumAxisDescriptor]
 
-    _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("access_descriptors",
-                                                          "args",
-                                                          "redn_axis_to_redn_descr",
-                                                          "index_to_access_descr")
     _mapper_method: ClassVar[str] = "map_einsum"
 
     @memoize_method
@@ -1315,7 +1282,7 @@ def einsum(subscripts: str, *operands: Array,
 
 # {{{ stack
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class Stack(Array):
     """Join a sequence of arrays along a new axis.
 
@@ -1331,7 +1298,6 @@ class Stack(Array):
     arrays: Tuple[Array, ...]
     axis: int
 
-    _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("arrays", "axis")
     _mapper_method: ClassVar[str] = "map_stack"
 
     @property
@@ -1349,7 +1315,7 @@ class Stack(Array):
 
 # {{{ concatenate
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class Concatenate(Array):
     """Join a sequence of arrays along an existing axis.
 
@@ -1365,7 +1331,6 @@ class Concatenate(Array):
     arrays: Tuple[Array, ...]
     axis: int
 
-    _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("arrays", "axis")
     _mapper_method: ClassVar[str] = "map_concatenate"
 
     @property
@@ -1387,7 +1352,7 @@ class Concatenate(Array):
 
 # {{{ index remapping
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class IndexRemappingBase(Array):
     """Base class for operations that remap the indices of an array.
 
@@ -1400,7 +1365,6 @@ class IndexRemappingBase(Array):
     """
     array: Array
 
-    _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("array",)
 
     @property
     def dtype(self) -> np.dtype[Any]:
@@ -1411,7 +1375,7 @@ class IndexRemappingBase(Array):
 
 # {{{ roll
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class Roll(IndexRemappingBase):
     """Roll an array along an axis.
 
@@ -1426,8 +1390,6 @@ class Roll(IndexRemappingBase):
     shift: int
     axis: int
 
-    _fields: ClassVar[Tuple[str, ...]] = IndexRemappingBase._fields + ("shift",
-                                                                       "axis")
     _mapper_method: ClassVar[str] = "map_roll"
 
     @property
@@ -1439,7 +1401,7 @@ class Roll(IndexRemappingBase):
 
 # {{{ axis permutation
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class AxisPermutation(IndexRemappingBase):
     r"""Permute the axes of an array.
 
@@ -1451,8 +1413,6 @@ class AxisPermutation(IndexRemappingBase):
     """
     axis_permutation: Tuple[int, ...]
 
-    _fields: ClassVar[Tuple[str, ...]] = (IndexRemappingBase._fields
-                                          + ("axis_permutation",))
     _mapper_method: ClassVar[str] = "map_axis_permutation"
 
     @property
@@ -1468,7 +1428,7 @@ class AxisPermutation(IndexRemappingBase):
 
 # {{{ reshape
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class Reshape(IndexRemappingBase):
     """
     Reshape an array.
@@ -1488,16 +1448,12 @@ class Reshape(IndexRemappingBase):
     newshape: ShapeType
     order: str
 
-    _fields: ClassVar[Tuple[str, ...]] = IndexRemappingBase._fields + ("newshape",
-                                                                       "order")
     _mapper_method: ClassVar[str] = "map_reshape"
 
-    def __post_init__(self) -> None:
+    def __attrs_post_init__(self) -> None:
         # FIXME: Get rid of this restriction
         assert self.order == "C"
-        super().__post_init__()
-
-    __attrs_post_init__ = __post_init__
+        super().__attrs_post_init__()
 
     @property
     def shape(self) -> ShapeType:
@@ -1508,7 +1464,7 @@ class Reshape(IndexRemappingBase):
 
 # {{{ indexing
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class IndexBase(IndexRemappingBase):
     """
     Abstract class for all index expressions on an array.
@@ -1516,7 +1472,6 @@ class IndexBase(IndexRemappingBase):
     .. attribute:: indices
     """
     indices: Tuple[IndexExpr, ...]
-    _fields: ClassVar[Tuple[str, ...]] = IndexRemappingBase._fields + ("indices",)
 
 
 class BasicIndex(IndexBase):
@@ -1623,7 +1578,7 @@ class AdvancedIndexInNoncontiguousAxes(IndexBase):
 
 # {{{ base class for arguments
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class InputArgumentBase(Array):
     r"""Base class for input arguments.
 
@@ -1663,7 +1618,7 @@ class DataInterface(Protocol):
     pass
 
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=False)
 class DataWrapper(InputArgumentBase):
     """Takes concrete array data and packages it to be compatible with the
     :class:`Array` interface.
@@ -1706,8 +1661,6 @@ class DataWrapper(InputArgumentBase):
     data: DataInterface
     _shape: ShapeType
 
-    _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("data",
-                                                          "shape")
     _mapper_method: ClassVar[str] = "map_data_wrapper"
 
     @property
@@ -1736,7 +1689,7 @@ class DataWrapper(InputArgumentBase):
 
 # {{{ placeholder
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class Placeholder(_SuppliedShapeAndDtypeMixin, InputArgumentBase):
     r"""A named placeholder for an array whose concrete value is supplied by the
     user during evaluation.
@@ -1749,12 +1702,6 @@ class Placeholder(_SuppliedShapeAndDtypeMixin, InputArgumentBase):
     .. automethod:: __init__
     """
     name: str
-    _shape: ShapeType
-    _dtype: np.dtype[Any]
-
-    _fields: ClassVar[Tuple[str, ...]] = InputArgumentBase._fields + ("shape",
-                                                                      "dtype",
-                                                                      "name")
 
     _mapper_method: ClassVar[str] = "map_placeholder"
 
@@ -1763,7 +1710,7 @@ class Placeholder(_SuppliedShapeAndDtypeMixin, InputArgumentBase):
 
 # {{{ size parameter
 
-@attrs.define(frozen=True, eq=False, repr=False)
+@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
 class SizeParam(InputArgumentBase):
     r"""A named placeholder for a scalar that may be used as a variable in symbolic
     expressions for array sizes.
@@ -1777,7 +1724,6 @@ class SizeParam(InputArgumentBase):
     axes: AxesT = attrs.field(kw_only=True, default=())
 
     _mapper_method: ClassVar[str] = "map_size_param"
-    _fields: ClassVar[Tuple[str, ...]] = InputArgumentBase._fields + ("name",)
 
     @property
     def shape(self) -> ShapeType:
diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py
index 3dc5721a11050a45360de3c4923be32c2bad1916..7ff84ad81dbd72d89d12e92137018f610f5e6601 100644
--- a/pytato/distributed/nodes.py
+++ b/pytato/distributed/nodes.py
@@ -17,6 +17,16 @@ For completeness, individual (non-held/"stapled") :class:`DistributedSend`
 nodes can be made via this function:
 
 .. autofunction:: make_distributed_send
+
+Redirections for the documentation tool
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. currentmodule:: np
+
+.. class:: dtype
+
+    See :class:`numpy.dtype`.
+
 """
 
 from __future__ import annotations
@@ -45,7 +55,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
 
-from typing import Hashable, FrozenSet, Optional, Any, cast, ClassVar, Tuple
+from typing import Hashable, FrozenSet, Optional, Any, ClassVar
 
 import attrs
 import numpy as np
@@ -61,6 +71,7 @@ CommTagType = Hashable
 
 # {{{ send
 
+@attrs.frozen(init=True, eq=True, hash=True, cache_hash=True)
 class DistributedSend(Taggable):
     """Class representing a distributed send operation. See
     :class:`DistributedSendRefHolder` for a way to ensure that nodes
@@ -81,61 +92,23 @@ class DistributedSend(Taggable):
         receive the data being sent here.
""" - def __init__(self, data: Array, dest_rank: int, comm_tag: CommTagType, - tags: FrozenSet[Tag] = frozenset()) -> None: - super().__init__(tags=tags) - self.data = data - self.dest_rank = dest_rank - self.comm_tag = comm_tag - - def __hash__(self) -> int: - return ( - hash(self.__class__) - ^ hash(self.data) - ^ hash(self.dest_rank) - ^ hash(self.comm_tag) - ^ hash(self.tags) - ) - - def __eq__(self, other: Any) -> bool: - return ( - self.__class__ is other.__class__ - and self.data == other.data - and self.dest_rank == other.dest_rank - and self.comm_tag == other.comm_tag - and self.tags == other.tags) + data: Array + dest_rank: int + comm_tag: CommTagType + tags: FrozenSet[Tag] = attrs.field(kw_only=True, default=frozenset()) def _with_new_tags(self, tags: FrozenSet[Tag]) -> DistributedSend: - return DistributedSend( - data=self.data, - dest_rank=self.dest_rank, - comm_tag=self.comm_tag, - tags=tags) + return attrs.evolve(self, tags=tags) def copy(self, **kwargs: Any) -> DistributedSend: - data: Optional[Array] = kwargs.get("data") - dest_rank: Optional[int] = kwargs.get("dest_rank") - comm_tag: Optional[CommTagType] = kwargs.get("comm_tag") - tags = cast(FrozenSet[Tag], kwargs.get("tags")) - return type(self)( - data=data if data is not None else self.data, - dest_rank=dest_rank if dest_rank is not None else self.dest_rank, - comm_tag=comm_tag if comm_tag is not None else self.comm_tag, - tags=tags if tags is not None else self.tags) - - def __repr__(self) -> str: - # self.data takes a lot of space, shorten it - return (f"DistributedSend(data={self.data.__class__} " - f"at {hex(id(self.data))}, " - f"dest_rank={self.dest_rank}, " - f"tags={self.tags}, comm_tag={self.comm_tag})") + return attrs.evolve(self, **kwargs) # }}} # {{{ send ref holder -@attrs.define(frozen=True, eq=False, repr=False, init=False) +@attrs.frozen(eq=False, repr=False, init=False, hash=True) class DistributedSendRefHolder(Array): """A node acting as an identity on :attr:`passthrough_data` while also holding a reference to a :class:`DistributedSend` in :attr:`send`. Since @@ -174,7 +147,6 @@ class DistributedSendRefHolder(Array): passthrough_data: Array _mapper_method: ClassVar[str] = "map_distributed_send_ref_holder" - _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("passthrough_data", "send") def __init__(self, send: DistributedSend, passthrough_data: Array, tags: FrozenSet[Tag] = frozenset()) -> None: @@ -209,7 +181,7 @@ class DistributedSendRefHolder(Array): # {{{ receive -@attrs.define(frozen=True, eq=False) +@attrs.frozen(eq=False, hash=True, cache_hash=True) class DistributedRecv(_SuppliedShapeAndDtypeMixin, Array): """Class representing a distributed receive operation. @@ -238,11 +210,7 @@ class DistributedRecv(_SuppliedShapeAndDtypeMixin, Array): """ src_rank: int comm_tag: CommTagType - _shape: ShapeType - _dtype: Any # FIXME: sphinx does not like `_dtype: _dtype_any` - _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("shape", "dtype", - "src_rank", "comm_tag") _mapper_method: ClassVar[str] = "map_distributed_recv" # }}} diff --git a/pytato/function.py b/pytato/function.py index 004aac18a3b5c188ca02dbd02ea30d9474fbac2f..4441c52afb885709b489f746e23f1b6c71126362 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -276,11 +276,11 @@ class Call(AbstractResultWithNamedArrays): copy = attrs.evolve - def __post_init__(self) -> None: + def __attrs_post_init__(self) -> None: # check that the invocation parameters and the function definition # parameters agree with each other. 
         assert frozenset(self.bindings) == self.function.parameters
-        super().__post_init__()
+        super().__attrs_post_init__()
 
     def __contains__(self, name: object) -> bool:
         return name in self.function.returns
diff --git a/pytato/loopy.py b/pytato/loopy.py
index 11ac1bbe9ed55f82f0b26e72631ae5f420c52891..3d1ee157250dd952449488510bce5d7904fe89a9 100644
--- a/pytato/loopy.py
+++ b/pytato/loopy.py
@@ -33,11 +33,10 @@ from typing import (Dict, Optional, Any, Iterator, FrozenSet, Union, Sequence,
                     Tuple, Iterable, Mapping, ClassVar)
 from numbers import Number
 from pytato.array import (AbstractResultWithNamedArrays, Array, ShapeType,
-                          NamedArray, ArrayOrScalar, SizeParam, AxesT)
+                          NamedArray, ArrayOrScalar, SizeParam)
 from pytato.scalar_expr import (SubstitutionMapper, ScalarExpression,
                                 EvaluationMapper, IntegralT)
 from pytools import memoize_method
-from pytools.tag import Tag
 from immutables import Map
 
 import islpy as isl
@@ -59,10 +58,20 @@ Internal stuff that is only here because the documentation tool wants it
 .. class:: AxesT
 
     See :class:`pytato.array.AxesT`.
+
+.. class:: ArrayOrScalar
+
+    A :class:`~pytato.Array` or a scalar.
+
+.. currentmodule:: lp
+
+.. class:: TranslationUnit
+
+    See :class:`loopy.TranslationUnit`.
 """
 
 
-@attrs.define(eq=False, frozen=True)
+@attrs.frozen(eq=False)
 class LoopyCall(AbstractResultWithNamedArrays):
     """
     An array expression node representing a call to an entrypoint in a
@@ -106,11 +115,13 @@ class LoopyCall(AbstractResultWithNamedArrays):
             raise KeyError(name)
 
         # TODO: Attach a filtered set of tags from loopy's arg.
-        return LoopyCallResult(self, name,
+        return LoopyCallResult(container=self,
+                               name=name,
                                axes=_get_default_axes(len(self
                                                           ._entry_kernel
                                                           .arg_dict[name]
-                                                          .shape)))
+                                                          .shape)),
+                               tags=frozenset())
 
     def __len__(self) -> int:
         return len(self._result_names)
@@ -119,35 +130,14 @@ class LoopyCall(AbstractResultWithNamedArrays):
         return iter(self._result_names)
 
 
+@attrs.frozen(eq=False, hash=True, cache_hash=True)
 class LoopyCallResult(NamedArray):
     """
     Named array for :class:`LoopyCall`'s result.
     Inherits from :class:`~pytato.array.NamedArray`.
""" _mapper_method = "map_loopy_call_result" - - def __init__(self, - loopy_call: LoopyCall, - name: str, - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()) -> None: - super().__init__(loopy_call, name, axes=axes, tags=tags) - - # type-ignore reason: `copy` signature incompatible with super-class - def copy(self, *, # type: ignore[override] - loopy_call: Optional[AbstractResultWithNamedArrays] = None, - name: Optional[str] = None, - axes: Optional[AxesT] = None, - tags: Optional[FrozenSet[Tag]] = None) -> LoopyCallResult: - loopy_call = self._container if loopy_call is None else loopy_call - name = self.name if name is None else name - tags = self.tags if tags is None else tags - axes = self.axes if axes is None else axes - assert isinstance(loopy_call, LoopyCall) - return LoopyCallResult(loopy_call=loopy_call, - name=name, - axes=axes, - tags=tags) + _container: LoopyCall @property def expr(self) -> Array: diff --git a/pytato/stringifier.py b/pytato/stringifier.py index f411da225da166aa6c4f5ef4a1dddc3100756bd7..eca3a0d54109f373a75842a26ca712719440a745 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -32,6 +32,7 @@ from pytato.array import (Array, DataWrapper, DictOfNamedArrays, Axis, IndexLambda, ReductionDescriptor) from pytato.loopy import LoopyCall from immutables import Map +import attrs __doc__ = """ @@ -93,7 +94,7 @@ class Reprifier(Mapper): if depth > self.truncation_depth: return self.truncation_string - fields = expr._fields + fields = tuple(field.name for field in attrs.fields(type(expr))) if expr.ndim <= 1: # prettify: if ndim <=1 'expr.axes' would be trivial, @@ -150,8 +151,8 @@ class Reprifier(Mapper): return self.rec(getattr(expr, field), depth+1) return (f"{type(expr).__name__}(" - + ", ".join(f"{field}={_get_field_val(field)}" - for field in expr._fields) + + ", ".join(f"{field.name}={_get_field_val(field.name)}" + for field in attrs.fields(type(expr))) + ")") def map_loopy_call(self, expr: LoopyCall, depth: int) -> str: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index e192123f1bd439d416862ea8187bee8f03548205..593eee0627439e0e29dbb9b8faa1736cc7ce58cd 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -368,7 +368,7 @@ class CopyMapper(CachedMapper[ArrayOrNames]): rec_container = self.rec(expr._container) assert isinstance(rec_container, LoopyCall) return LoopyCallResult( - loopy_call=rec_container, + container=rec_container, name=expr.name, axes=expr.axes, tags=expr.tags) @@ -594,7 +594,7 @@ class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): rec_loopy_call = self.rec(expr._container, *args, **kwargs) assert isinstance(rec_loopy_call, LoopyCall) return LoopyCallResult( - loopy_call=rec_loopy_call, + container=rec_loopy_call, name=expr.name, axes=expr.axes, tags=expr.tags) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 0008cea20859c0c2accdd260e022863e2055eb59..c4ade1a69827b281ca3f054a554cc20c83b2e77a 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -184,24 +184,24 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): # Default handler, does its best to guess how to handle fields. 
         info = self.get_common_dot_info(expr)
 
-        for field in expr._fields:
-            if field in info.fields:
+        for field in attrs.fields(type(expr)):
+            if field.name in info.fields:
                 continue
 
-            attr = getattr(expr, field)
+            attr = getattr(expr, field.name)
 
             if isinstance(attr, Array):
                 self.rec(attr)
-                info.edges[field] = attr
+                info.edges[field.name] = attr
 
             elif isinstance(attr, AbstractResultWithNamedArrays):
                 self.rec(attr)
-                info.edges[field] = attr
+                info.edges[field.name] = attr
 
             elif isinstance(attr, tuple):
-                info.fields[field] = stringify_shape(attr)
+                info.fields[field.name] = stringify_shape(attr)
 
             else:
-                info.fields[field] = str(attr)
+                info.fields[field.name] = str(attr)
 
         self.node_to_dot[expr] = info
diff --git a/test/test_pytato.py b/test/test_pytato.py
index 5b35bdde7d9d374c3863116c26ff023dfb336de2..4ace4f3ed60664fb5c115bb35bafb16cd6732b8e 100644
--- a/test/test_pytato.py
+++ b/test/test_pytato.py
@@ -459,25 +459,27 @@ def test_array_dot_repr():
         3*x + 4*y,
         """
 IndexLambda(
+    shape=(10, 4),
+    dtype='int64',
     expr=Sum((Subscript(Variable('_in0'),
                         (Variable('_0'), Variable('_1'))),
               Subscript(Variable('_in1'),
                         (Variable('_0'), Variable('_1'))))),
-    shape=(10, 4),
-    dtype='int64',
-    bindings={'_in0': IndexLambda(expr=Product((3, Subscript(Variable('_in1'),
-                                                             (Variable('_0'),
-                                                              Variable('_1'))))),
+    bindings={'_in0': IndexLambda(
                                   shape=(10, 4),
                                   dtype='int64',
+                                  expr=Product((3, Subscript(Variable('_in1'),
+                                                             (Variable('_0'),
+                                                              Variable('_1'))))),
                                   bindings={'_in1': Placeholder(shape=(10, 4),
                                                                 dtype='int64',
                                                                 name='x')}),
-              '_in1': IndexLambda(expr=Product((4, Subscript(Variable('_in1'),
-                                                             (Variable('_0'),
-                                                              Variable('_1'))))),
+              '_in1': IndexLambda(
                                   shape=(10, 4),
                                   dtype='int64',
+                                  expr=Product((4, Subscript(Variable('_in1'),
+                                                             (Variable('_0'),
+                                                              Variable('_1'))))),
                                   bindings={'_in1': Placeholder(shape=(10, 4),
                                                                 dtype='int64',
                                                                 name='y')})})""")
@@ -497,20 +499,20 @@ Roll(
     _assert_stripped_repr(y * pt.not_equal(x, 3),
         """
 IndexLambda(
+    shape=(10, 4),
+    dtype='int64',
     expr=Product((Subscript(Variable('_in0'),
                             (Variable('_0'), Variable('_1'))),
                   Subscript(Variable('_in1'),
                             (Variable('_0'), Variable('_1'))))),
-    shape=(10, 4),
-    dtype='int64',
     bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='y'),
              '_in1': IndexLambda(
+                                 shape=(10, 4),
+                                 dtype='bool',
                                  expr=Comparison(Subscript(Variable('_in0'),
                                                            (Variable('_0'),
                                                             Variable('_1'))),
                                                  '!=', 3),
-                                 shape=(10, 4),
-                                 dtype='bool',
                                  bindings={'_in0': Placeholder(shape=(10, 4),
                                                                dtype='int64',
                                                                name='x')})})""")
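
The hand-written ``_fields``/``copy()``/``__hash__`` machinery removed throughout this patch is subsumed by stock attrs features. As a minimal, self-contained sketch (the ``Node`` class below is hypothetical, not pytato code), these are the idioms in play: a frozen class with an explicitly requested, cached ``__hash__``; a field validator in the spirit of the new ``IndexLambda.bindings`` check; ``attrs.evolve()`` as the generic copy-with-changes; and ``attrs.fields()`` for the introspection now used by the stringifier and the DOT visualizer::

    # Illustrative sketch only -- "Node" is a made-up example class, not part
    # of pytato; it just exercises the attrs features relied on above.
    from typing import Any, Tuple

    import attrs


    @attrs.frozen(hash=True, cache_hash=True)
    class Node:
        name: str
        children: Tuple["Node", ...] = attrs.field(default=())

        @children.validator
        def _check_children(self, attribute: Any, value: Any) -> None:
            # same flavor as the IndexLambda.bindings validator: reject a
            # mutable container where an immutable one is expected
            if isinstance(value, list):
                raise TypeError("children may not be a list")


    n = Node("a", (Node("b"),))
    m = attrs.evolve(n, name="c")   # copy-with-changes, as Array.copy() now does
    assert [f.name for f in attrs.fields(Node)] == ["name", "children"]
    assert hash(n) == hash(n)       # generated __hash__, computed once and cached

In the pytato classes themselves, ``eq=False`` additionally keeps the hand-written ``__eq__`` (which dispatches to ``EqualityComparer``) while still letting attrs generate the cached ``__hash__``.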