From ffb9928f8b6c0939ad355ded01076ae7fed64319 Mon Sep 17 00:00:00 2001
From: Nicholas Christensen <njchris2@illinois.edu>
Date: Tue, 24 Nov 2020 06:06:06 -0600
Subject: [PATCH] use Taggable in pytools

---
 pytato/array.py | 31 ++++++++++++-------------------
 setup.py        |  2 +-
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/pytato/array.py b/pytato/array.py
index cae6263..334467f 100644
--- a/pytato/array.py
+++ b/pytato/array.py
@@ -152,7 +152,8 @@ import numpy as np
 import pymbolic.primitives as prim
 from pymbolic import var
 from pytools import is_single_valued, memoize_method, UniqueNameGenerator
-from pytools.tag import Tag, UniqueTag, TagsType, tag_dataclass
+from pytools.tag import Tag, Taggable, UniqueTag, TagsType, TagOrTagsType,
+        tag_dataclass
 
 import pytato.scalar_expr as scalar_expr
 from pytato.scalar_expr import ScalarExpression, IntegralScalarExpression
@@ -329,7 +330,7 @@ def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype:
         return dtype
 
 
-class Array:
+class Array(Taggable):
     """
     A base class (abstract interface + supplemental functionality) for lazily
     evaluating array expressions. The interface seeks to maximize :mod:`numpy`
@@ -410,11 +411,8 @@ class Array:
     # hashable. Dicts of hashable keys and values are also permitted.
     _fields: ClassVar[Tuple[str, ...]] = ("shape", "dtype", "tags")
 
-    def __init__(self, tags: Optional[TagsType] = None):
-        if tags is None:
-            tags = frozenset()
-
-        self.tags = tags
+    def __init__(self, tags: TagOrTagsType = frozenset()):
+        super(self, tags=tags)
 
     def copy(self, **kwargs: Any) -> Array:
         raise NotImplementedError
@@ -530,22 +528,17 @@ class Array:
     def T(self) -> Array:
         return AxisPermutation(self, tuple(range(self.ndim)[::-1]))
 
-    def tagged(self, tag: Tag) -> Array:
+    def tagged(self, tags: TagOrTagsType) -> Array:
         """
-        Returns a copy of *self* tagged with *tag*.
-        If *tag* is a :class:`pytools.tag.UniqueTag` and other
+        Returns a copy of *self* tagged with *tags*.
+        If *tags* is/has a :class:`pytools.tag.UniqueTag` and other
         tags of this type are already present, an error
         is raised.
         """
-        return self.copy(tags=self.tags | frozenset([tag]))
-
-    def without_tag(self, tag: Tag, verify_existence: bool = True) -> Array:
-        new_tags = tuple(t for t in self.tags if t != tag)
-
-        if verify_existence and len(new_tags) == len(self.tags):
-            raise ValueError(f"tag '{tag}' was not present")
+        return super.tagged(tags)
 
-        return self.copy(tags=new_tags)
+    def without_tags(self, tags: TagOrTagsType, verify_existence: bool = True) -> Array:
+        return super.without_tags(tags, verify_existence=verify_existence)
 
     @memoize_method
     def __hash__(self) -> int:
@@ -1234,7 +1227,7 @@ class InputArgumentBase(Array):
     def tagged(self, tag: Tag) -> Array:
         raise ValueError("Cannot modify tags")
 
-    def without_tag(self, tag: Tag, verify_existence: bool = True) -> Array:
+    def without_tags(self, tag: Tag, verify_existence: bool = True) -> Array:
         raise ValueError("Cannot modify tags")
 
 # }}}
diff --git a/setup.py b/setup.py
index c5a2e6a..86db32e 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@ setup(name="pytato",
       python_requires="~=3.8",
       install_requires=[
           "loopy>=2020.2",
-          "pytools>=2020.4.2"
+          "pytools>=2020.4.4"
           ],
 
       author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei",
-- 
GitLab