diff --git a/pytato/array.py b/pytato/array.py index 42041060ff83432a4867c81172b529326741005f..9027342c6b9a58c17b44f93edd88356357dec4bd 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -67,6 +67,7 @@ Built-in Expression Nodes .. autoclass:: Reshape .. autoclass:: DataWrapper .. autoclass:: Placeholder +.. autoclass:: Output .. autoclass:: LoopyFunction User-Facing Node Creation @@ -85,16 +86,21 @@ Node constructors such as :class:`Placeholder.__init__` and # }}} +import collections +from functools import partialmethod +from numbers import Number +import operator +from dataclasses import dataclass +from typing import Optional, Dict, Any, MutableMapping, Mapping, Iterator, Tuple, Union, FrozenSet + import numpy as np +import pymbolic.mapper import pymbolic.primitives as prim +from pytools import is_single_valued, memoize_method + import pytato.scalar_expr as scalar_expr from pytato.scalar_expr import ScalarExpression -from dataclasses import dataclass -from pytools import is_single_valued -from typing import Optional, Dict, Any, Mapping, Iterator, Tuple, Union, FrozenSet - - # {{{ dotted name class DottedName: @@ -138,7 +144,37 @@ class DottedName: # {{{ namespace -class Namespace: +class _NamespaceCopyMapper(scalar_expr.IdentityMapper): + + def __call__(self, expr: Array, namespace: Namespace, cache: Dict[Array, Array]) -> Array: + return self.rec(expr, namespace, cache) + + def rec(self, expr: Array, namespace: Namespace, cache: Dict[Array, Array]) -> Array: + if expr in cache: + return cache[expr] + result: Array = super().rec(expr, namespace, cache) + cache[expr] = result + return result + + def map_index_lambda(self, expr: IndexLambda, namespace: Namespace, cache: Dict[Array, Array]) -> Array: + bindings = { + name: self.rec(subexpr, namespace, cache) + for name, subexpr in expr.bindings.items()} + return IndexLambda( + namespace, + expr=expr.expr, + shape=expr.shape, + dtype=expr.dtype, + bindings=bindings) + + def map_placeholder(self, expr: Placeholder, namespace: Namespace, cache: Dict[Array, Array]) -> Array: + return Placeholder(namespace, expr.name, expr.shape, expr.dtype, expr.tags) + + def map_output(self, expr: Output, namespace: Namespace, cache: Dict[Array, Array]) -> Array: + return Output(namespace, expr.name, self.rec(expr.array, namespace, cache), expr.tags) + + +class Namespace(Mapping[str, "Array"]): # Possible future extension: .parent attribute r""" Represents a mapping from :term:`identifier` strings to @@ -149,14 +185,18 @@ class Namespace: .. automethod:: __contains__ .. automethod:: __getitem__ .. automethod:: __iter__ + .. automethod:: __len__ .. automethod:: assign + .. automethod:: copy .. automethod:: ref """ - def __init__(self) -> None: - self._symbol_table: Dict[str, Optional[Array]] = {} + def __init__(self, _symbol_table: Optional[MutableMapping[str, Array]] = None) -> None: + if _symbol_table is None: + _symbol_table = {} + self._symbol_table: MutableMapping[str, Array] = _symbol_table - def __contains__(self, name: str) -> bool: + def __contains__(self, name: object) -> bool: return name in self._symbol_table def __getitem__(self, name: str) -> Array: @@ -168,13 +208,27 @@ class Namespace: def __iter__(self) -> Iterator[str]: return iter(self._symbol_table) - def assign(self, name: str, - value: Optional[Array]) -> str: + def __len__(self) -> int: + return len(self._symbol_table) + + def _chain(self) -> Namespace: + return Namespace(collections.ChainMap(dict(), self._symbol_table)) + + def copy(self) -> Namespace: + result = Namespace() + mapper = _NamespaceCopyMapper() + cache: Dict[Array, Array] = {} + for name in self: + val = mapper(self[name], result, cache) + if name not in result: + result.assign(name, val) + return result + + def assign(self, name: str, value: Array) -> str: """Declare a new array. :param name: a Python identifier - :param value: the array object, or None if the assignment is to - just reserve a name + :param value: the array object :returns: *name* """ @@ -239,6 +293,7 @@ class UniqueTag(Tag): Only one instance of this type of tag may be assigned to a single tagged object. """ + pass TagsType = FrozenSet[Tag] @@ -285,12 +340,10 @@ def normalize_shape( :param ns: if a namespace is given, extra checks are performed to ensure that expressions are well-defined. """ - from pytato.scalar_expr import parse - def normalize_shape_component( s: ConvertibleToShapeComponent) -> ScalarExpression: if isinstance(s, str): - s = parse(s) + s = scalar_expr.parse(s) if isinstance(s, int): if s < 0: @@ -303,7 +356,7 @@ def normalize_shape( return s if isinstance(shape, str): - shape = parse(shape) + shape = scalar_expr.parse(shape) from numbers import Number if isinstance(shape, (Number, prim.Expression)): @@ -376,21 +429,25 @@ class Array: """ - def __init__(self, namespace: Namespace, - tags: Optional[TagsType] = None): + def __init__(self, namespace: Namespace, shape: ShapeType, dtype: np.dtype, tags: Optional[TagsType] = None): if tags is None: tags = frozenset() self.namespace = namespace self.tags = tags - self.dtype: np.dtype = np.float64 # FIXME + self._shape = shape + self._dtype = dtype def copy(self, **kwargs: Any) -> Array: raise NotImplementedError @property def shape(self) -> ShapeType: - raise NotImplementedError + return self._shape + + @property + def dtype(self) -> np.dtype: + return self._dtype def named(self, name: str) -> Array: return self.namespace.ref(self.namespace.assign(name, self)) @@ -418,8 +475,56 @@ class Array: return self.copy(tags=new_tags) - # TODO: - # - codegen interface + @memoize_method + def __hash__(self) -> int: + raise NotImplementedError + + def __eq__(self, other: Any) -> bool: + raise NotImplementedError + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def _join_dtypes(self, *args: np.dtype) -> np.dtype: + result = args[0] + for arg in args[1:]: + result = (np.empty(0, dtype=result) + np.empty(0, dtype=arg)).dtype + return result + + def _binary_op(self, op: Any, + other: Union[Array, Number], + reverse: bool = False) -> Array: + if isinstance(other, Number): + # TODO + raise NotImplementedError + else: + if self.shape != other.shape: + raise ValueError("shapes do not match for binary operator") + + dtype = self._join_dtypes(self.dtype, other.dtype) + + # FIXME: If either *self* or *other* is an IndexLambda, its expression + # could be folded into the output, producing a fused result. + if self.shape == (): + expr = op(prim.Variable("_in0"), prim.Variable("_in1")) + else: + indices = tuple(prim.Variable(f"_{i}") for i in range(self.ndim)) + expr = op( + prim.Variable("_in0")[indices], + prim.Variable("_in1")[indices]) + + first, second = self, other + if reverse: + first, second = second, first + + bindings = dict(_in0=first, _in1=second) + + return IndexLambda( + self.namespace, expr, + shape=self.shape, dtype=dtype, bindings=bindings) + + __mul__ = partialmethod(_binary_op, operator.mul) + __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True) # }}} @@ -538,32 +643,26 @@ class IndexLambda(Array): .. automethod:: is_reference """ - # TODO: write make_index_lambda() that does dtype inference + mapper_method = "map_index_lambda" def __init__( - self, namespace: Namespace, expr: prim.Expression, - shape: ShapeType, dtype: np.dtype, + self, + namespace: Namespace, + expr: prim.Expression, + shape: ShapeType, + dtype: np.dtype, bindings: Optional[Dict[str, Array]] = None, tags: Optional[TagsType] = None): if bindings is None: bindings = {} - super().__init__(namespace, tags=tags) + super().__init__(namespace, shape=shape, dtype=dtype, tags=tags) - self._shape = shape - self._dtype = dtype self.expr = expr self.bindings = bindings - @property - def shape(self) -> ShapeType: - return self._shape - - @property - def dtype(self) -> np.dtype: - return self._dtype - + @memoize_method def is_reference(self) -> bool: # FIXME: Do we want a specific 'reference' node to make all this # checking unnecessary? @@ -594,6 +693,28 @@ class IndexLambda(Array): return True + @memoize_method + def __hash__(self) -> int: + return hash(( + self.expr, + self.shape, + self.dtype, + frozenset(self.bindings.items()), + self.tags)) + + def __eq__(self, other: object) -> bool: + if self is other: + return True + + return ( + isinstance(other, IndexLambda) + and self.namespace is other.namespace + and self.expr == other.expr + and self.shape == other.shape + and self.dtype == other.dtype + and self.bindings == other.bindings + and self.tags == other.tags) + # }}} @@ -638,24 +759,54 @@ class DataWrapper(Array): # TODO: not really Any data def __init__(self, namespace: Namespace, data: Any, tags: Optional[TagsType] = None): - super().__init__(namespace, tags) - + super().__init__(namespace, shape=data.shape, dtype=data.dtype, tags=tags) self.data = data - @property - def shape(self) -> Any: # FIXME - return self.data.shape - - @property - def dtype(self) -> np.dtype: - return self.data.dtype - # }}} # {{{ placeholder -class Placeholder(Array): +class _ArgLike(Array): + + def __init__(self, + namespace: Namespace, + name: str, + shape: ShapeType, + dtype: np.dtype, + tags: Optional[TagsType] = None): + if name is None: + raise ValueError("Must have explicit name") + + # Reserve the name, prevent others from using it. + namespace.assign(name, self) + + super().__init__( + namespace=namespace, shape=shape, dtype=dtype, tags=tags) + + self.name = name + + @memoize_method + def __hash__(self) -> int: + return hash((self.name,)) + + def __eq__(self, other: object) -> bool: + if self is other: + return True + # Uniquely identified by name. + return ( + isinstance(other, _ArgLike) + and self.namespace is other.namespace + and self.name == other.name) + + def tagged(self, tag: Tag) -> Array: + raise ValueError("Cannot modify tags") + + def without_tag(self, tag: Tag, verify_existence: bool = True) -> Array: + raise ValueError("Cannot modify tags") + + +class Placeholder(_ArgLike): """ A named placeholder for an array whose concrete value is supplied by the user during evaluation. @@ -667,29 +818,46 @@ class Placeholder(Array): .. note:: - :attr:`name` is not a :term:`namespace name`. In fact, - it is prohibited from being one. (This has to be the case: Suppose a - :class:`Placeholder` is :meth:`~Array.tagged`, would the namespace name - refer to the tagged or the untagged version?) + Modifying :class:`Placeholder` tags is not supported after + creation. """ - def __init__(self, namespace: Namespace, - name: str, shape: ShapeType, - tags: Optional[TagsType] = None): + mapper_method = "map_placeholder" - # Reserve the name, prevent others from using it. - namespace.assign(name, None) +# }}} - super().__init__(namespace=namespace, tags=tags) - self.name = name - self._shape = shape +# {{{ output - @property - def shape(self) -> ShapeType: - # Matt added this to make Pylint happy. - # Not tied to this, open for discussion about how to implement this. - return self._shape +class Output(_ArgLike): + """A named output of the computation. + + .. attribute:: name + + The name of the output array. + + .. attribute:: array + + The :class:`Array` value that is output. + + .. note:: + + Modifying :class:`Output` tags is not supported after creation. + """ + + mapper_method = "map_output" + + def __init__(self, + namespace: Namespace, + name: str, + array: Array, + tags: Optional[TagsType] = None): + super().__init__(namespace=namespace, + name=name, + shape=array.shape, + dtype=array.dtype, + tags=tags) + self.array = array # }}} @@ -726,20 +894,26 @@ def make_dict_of_named_arrays( def make_placeholder(namespace: Namespace, name: str, shape: ConvertibleToShape, + dtype: np.dtype, tags: Optional[TagsType] = None ) -> Placeholder: """Make a :class:`Placeholder` object. - :param namespace: namespace of the placeholder array - :param shape: shape of the placeholder array - :param tags: implementation tags + :param namespace: namespace of the placeholder array + :param name: name of the placeholder array + :param shape: shape of the placeholder array + :param dtype: dtype of the placeholder array + :param tags: implementation tags """ + if name is None: + raise ValueError("Placeholder instances must have a name") + if not str.isidentifier(name): - raise ValueError(f"{name} is not a Python identifier") + raise ValueError(f"'{name}' is not a Python identifier") shape = normalize_shape(shape, namespace) - return Placeholder(namespace, name, shape, tags) + return Placeholder(namespace, name, shape, dtype, tags) # }}}