From c9a4aad9cd4c34acaa295f3d337b38c59999956b Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Wed, 3 Jun 2020 01:26:15 -0500 Subject: [PATCH] Fix array and array_expr --- pytato/array.py | 59 +++++++------------------------------------- pytato/array_expr.py | 4 +-- setup.cfg | 3 +++ 3 files changed, 14 insertions(+), 52 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 9027342..15237ce 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -86,21 +86,20 @@ Node constructors such as :class:`Placeholder.__init__` and # }}} -import collections from functools import partialmethod from numbers import Number import operator from dataclasses import dataclass -from typing import Optional, Dict, Any, MutableMapping, Mapping, Iterator, Tuple, Union, FrozenSet +from typing import Optional, Dict, Any, Mapping, Iterator, Tuple, Union, FrozenSet import numpy as np -import pymbolic.mapper import pymbolic.primitives as prim from pytools import is_single_valued, memoize_method import pytato.scalar_expr as scalar_expr from pytato.scalar_expr import ScalarExpression + # {{{ dotted name class DottedName: @@ -144,36 +143,6 @@ class DottedName: # {{{ namespace -class _NamespaceCopyMapper(scalar_expr.IdentityMapper): - - def __call__(self, expr: Array, namespace: Namespace, cache: Dict[Array, Array]) -> Array: - return self.rec(expr, namespace, cache) - - def rec(self, expr: Array, namespace: Namespace, cache: Dict[Array, Array]) -> Array: - if expr in cache: - return cache[expr] - result: Array = super().rec(expr, namespace, cache) - cache[expr] = result - return result - - def map_index_lambda(self, expr: IndexLambda, namespace: Namespace, cache: Dict[Array, Array]) -> Array: - bindings = { - name: self.rec(subexpr, namespace, cache) - for name, subexpr in expr.bindings.items()} - return IndexLambda( - namespace, - expr=expr.expr, - shape=expr.shape, - dtype=expr.dtype, - bindings=bindings) - - def map_placeholder(self, expr: Placeholder, namespace: Namespace, cache: Dict[Array, Array]) -> Array: - return Placeholder(namespace, expr.name, expr.shape, expr.dtype, expr.tags) - - def map_output(self, expr: Output, namespace: Namespace, cache: Dict[Array, Array]) -> Array: - return Output(namespace, expr.name, self.rec(expr.array, namespace, cache), expr.tags) - - class Namespace(Mapping[str, "Array"]): # Possible future extension: .parent attribute r""" @@ -191,10 +160,8 @@ class Namespace(Mapping[str, "Array"]): .. automethod:: ref """ - def __init__(self, _symbol_table: Optional[MutableMapping[str, Array]] = None) -> None: - if _symbol_table is None: - _symbol_table = {} - self._symbol_table: MutableMapping[str, Array] = _symbol_table + def __init__(self) -> None: + self._symbol_table: Dict[str, Array] = {} def __contains__(self, name: object) -> bool: return name in self._symbol_table @@ -211,18 +178,9 @@ class Namespace(Mapping[str, "Array"]): def __len__(self) -> int: return len(self._symbol_table) - def _chain(self) -> Namespace: - return Namespace(collections.ChainMap(dict(), self._symbol_table)) - def copy(self) -> Namespace: - result = Namespace() - mapper = _NamespaceCopyMapper() - cache: Dict[Array, Array] = {} - for name in self: - val = mapper(self[name], result, cache) - if name not in result: - result.assign(name, val) - return result + from pytato.array_expr import CopyMapper, copy_namespace + return copy_namespace(self, CopyMapper(Namespace())) def assign(self, name: str, value: Array) -> str: """Declare a new array. @@ -429,7 +387,8 @@ class Array: """ - def __init__(self, namespace: Namespace, shape: ShapeType, dtype: np.dtype, tags: Optional[TagsType] = None): + def __init__(self, namespace: Namespace, shape: ShapeType, dtype: np.dtype, + tags: Optional[TagsType] = None): if tags is None: tags = frozenset() @@ -794,7 +753,7 @@ class _ArgLike(Array): if self is other: return True # Uniquely identified by name. - return ( + return ( isinstance(other, _ArgLike) and self.namespace is other.namespace and self.name == other.name) diff --git a/pytato/array_expr.py b/pytato/array_expr.py index 4c68081..fdd8202 100644 --- a/pytato/array_expr.py +++ b/pytato/array_expr.py @@ -91,8 +91,8 @@ class CopyMapper(Mapper): def copy_namespace(namespace: Namespace, copy_mapper: CopyMapper) -> Namespace: """Copy the elements of *namespace* into a new namespace. - :param namespace: The original namespace - :param mapper: A mapper that performs copies into a new namespace + :param namespace: The source namespace + :param copy_mapper: A mapper that performs copies into a new namespace :returns: The new namespace """ for name, val in namespace.items(): diff --git a/setup.cfg b/setup.cfg index 7b0c351..7c3f3d4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,6 +2,9 @@ ignore = E126,E127,E128,E123,E226,E241,E242,E265,N802,W503,E402,N814,N817,W504 max-line-length=85 +[mypy-pytato.array_expr] +disallow_subclassing_any = False + [mypy-pytato.scalar_expr] disallow_subclassing_any = False -- GitLab