From c9a4aad9cd4c34acaa295f3d337b38c59999956b Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Wed, 3 Jun 2020 01:26:15 -0500
Subject: [PATCH] Fix array and array_expr

---
 pytato/array.py      | 59 +++++++-------------------------------------
 pytato/array_expr.py |  4 +--
 setup.cfg            |  3 +++
 3 files changed, 14 insertions(+), 52 deletions(-)

diff --git a/pytato/array.py b/pytato/array.py
index 9027342..15237ce 100644
--- a/pytato/array.py
+++ b/pytato/array.py
@@ -86,21 +86,20 @@ Node constructors such as :class:`Placeholder.__init__` and
 
 # }}}
 
-import collections
 from functools import partialmethod
 from numbers import Number
 import operator
 from dataclasses import dataclass
-from typing import Optional, Dict, Any, MutableMapping, Mapping, Iterator, Tuple, Union, FrozenSet
+from typing import Optional, Dict, Any, Mapping, Iterator, Tuple, Union, FrozenSet
 
 import numpy as np
-import pymbolic.mapper
 import pymbolic.primitives as prim
 from pytools import is_single_valued, memoize_method
 
 import pytato.scalar_expr as scalar_expr
 from pytato.scalar_expr import ScalarExpression
 
+
 # {{{ dotted name
 
 class DottedName:
@@ -144,36 +143,6 @@ class DottedName:
 
 # {{{ namespace
 
-class _NamespaceCopyMapper(scalar_expr.IdentityMapper):
-
-    def __call__(self, expr: Array, namespace: Namespace, cache: Dict[Array, Array]) -> Array:
-        return self.rec(expr, namespace, cache)
-
-    def rec(self, expr: Array, namespace: Namespace, cache: Dict[Array, Array]) -> Array:
-        if expr in cache:
-            return cache[expr]
-        result: Array = super().rec(expr, namespace, cache)
-        cache[expr] = result
-        return result
-
-    def map_index_lambda(self, expr: IndexLambda, namespace: Namespace, cache: Dict[Array, Array]) -> Array:
-        bindings = {
-                name: self.rec(subexpr, namespace, cache)
-                for name, subexpr in expr.bindings.items()}
-        return IndexLambda(
-                namespace,
-                expr=expr.expr,
-                shape=expr.shape,
-                dtype=expr.dtype,
-                bindings=bindings)
-
-    def map_placeholder(self, expr: Placeholder, namespace: Namespace, cache: Dict[Array, Array]) -> Array:
-        return Placeholder(namespace, expr.name, expr.shape, expr.dtype, expr.tags)
-
-    def map_output(self, expr: Output, namespace: Namespace, cache: Dict[Array, Array]) -> Array:
-        return Output(namespace, expr.name, self.rec(expr.array, namespace, cache), expr.tags)
-
-
 class Namespace(Mapping[str, "Array"]):
     # Possible future extension: .parent attribute
     r"""
@@ -191,10 +160,8 @@ class Namespace(Mapping[str, "Array"]):
     .. automethod:: ref
     """
 
-    def __init__(self, _symbol_table: Optional[MutableMapping[str, Array]] = None) -> None:
-        if _symbol_table is None:
-            _symbol_table = {}
-        self._symbol_table: MutableMapping[str, Array] = _symbol_table
+    def __init__(self) -> None:
+        self._symbol_table: Dict[str, Array] = {}
 
     def __contains__(self, name: object) -> bool:
         return name in self._symbol_table
@@ -211,18 +178,9 @@ class Namespace(Mapping[str, "Array"]):
     def __len__(self) -> int:
         return len(self._symbol_table)
 
-    def _chain(self) -> Namespace:
-        return Namespace(collections.ChainMap(dict(), self._symbol_table))
-
     def copy(self) -> Namespace:
-        result = Namespace()
-        mapper = _NamespaceCopyMapper()
-        cache: Dict[Array, Array] = {}
-        for name in self:
-            val = mapper(self[name], result, cache)
-            if name not in result:
-                result.assign(name, val)
-        return result
+        from pytato.array_expr import CopyMapper, copy_namespace
+        return copy_namespace(self, CopyMapper(Namespace()))
 
     def assign(self, name: str, value: Array) -> str:
         """Declare a new array.
@@ -429,7 +387,8 @@ class Array:
 
     """
 
-    def __init__(self, namespace: Namespace, shape: ShapeType, dtype: np.dtype, tags: Optional[TagsType] = None):
+    def __init__(self, namespace: Namespace, shape: ShapeType, dtype: np.dtype,
+            tags: Optional[TagsType] = None):
         if tags is None:
             tags = frozenset()
 
@@ -794,7 +753,7 @@ class _ArgLike(Array):
         if self is other:
             return True
         # Uniquely identified by name.
-        return  (
+        return (
                 isinstance(other, _ArgLike)
                 and self.namespace is other.namespace
                 and self.name == other.name)
diff --git a/pytato/array_expr.py b/pytato/array_expr.py
index 4c68081..fdd8202 100644
--- a/pytato/array_expr.py
+++ b/pytato/array_expr.py
@@ -91,8 +91,8 @@ class CopyMapper(Mapper):
 def copy_namespace(namespace: Namespace, copy_mapper: CopyMapper) -> Namespace:
     """Copy the elements of *namespace* into a new namespace.
 
-    :param namespace: The original namespace
-    :param mapper: A mapper that performs copies into a new namespace
+    :param namespace: The source namespace
+    :param copy_mapper: A mapper that performs copies into a new namespace
     :returns: The new namespace
     """
     for name, val in namespace.items():
diff --git a/setup.cfg b/setup.cfg
index 7b0c351..7c3f3d4 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -2,6 +2,9 @@
 ignore = E126,E127,E128,E123,E226,E241,E242,E265,N802,W503,E402,N814,N817,W504
 max-line-length=85
 
+[mypy-pytato.array_expr]
+disallow_subclassing_any = False
+
 [mypy-pytato.scalar_expr]
 disallow_subclassing_any = False
 
-- 
GitLab