From ffb9928f8b6c0939ad355ded01076ae7fed64319 Mon Sep 17 00:00:00 2001 From: Nicholas Christensen <njchris2@illinois.edu> Date: Tue, 24 Nov 2020 06:06:06 -0600 Subject: [PATCH] use Taggable in pytools --- pytato/array.py | 31 ++++++++++++------------------- setup.py | 2 +- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index cae6263..334467f 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -152,7 +152,8 @@ import numpy as np import pymbolic.primitives as prim from pymbolic import var from pytools import is_single_valued, memoize_method, UniqueNameGenerator -from pytools.tag import Tag, UniqueTag, TagsType, tag_dataclass +from pytools.tag import Tag, Taggable, UniqueTag, TagsType, TagOrTagsType, + tag_dataclass import pytato.scalar_expr as scalar_expr from pytato.scalar_expr import ScalarExpression, IntegralScalarExpression @@ -329,7 +330,7 @@ def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype: return dtype -class Array: +class Array(Taggable): """ A base class (abstract interface + supplemental functionality) for lazily evaluating array expressions. The interface seeks to maximize :mod:`numpy` @@ -410,11 +411,8 @@ class Array: # hashable. Dicts of hashable keys and values are also permitted. _fields: ClassVar[Tuple[str, ...]] = ("shape", "dtype", "tags") - def __init__(self, tags: Optional[TagsType] = None): - if tags is None: - tags = frozenset() - - self.tags = tags + def __init__(self, tags: TagOrTagsType = frozenset()): + super(self, tags=tags) def copy(self, **kwargs: Any) -> Array: raise NotImplementedError @@ -530,22 +528,17 @@ class Array: def T(self) -> Array: return AxisPermutation(self, tuple(range(self.ndim)[::-1])) - def tagged(self, tag: Tag) -> Array: + def tagged(self, tags: TagOrTagsType) -> Array: """ - Returns a copy of *self* tagged with *tag*. - If *tag* is a :class:`pytools.tag.UniqueTag` and other + Returns a copy of *self* tagged with *tags*. + If *tags* is/has a :class:`pytools.tag.UniqueTag` and other tags of this type are already present, an error is raised. """ - return self.copy(tags=self.tags | frozenset([tag])) - - def without_tag(self, tag: Tag, verify_existence: bool = True) -> Array: - new_tags = tuple(t for t in self.tags if t != tag) - - if verify_existence and len(new_tags) == len(self.tags): - raise ValueError(f"tag '{tag}' was not present") + return super.tagged(tags) - return self.copy(tags=new_tags) + def without_tags(self, tags: TagOrTagsType, verify_existence: bool = True) -> Array: + return super.without_tags(tags, verify_existence=verify_existence) @memoize_method def __hash__(self) -> int: @@ -1234,7 +1227,7 @@ class InputArgumentBase(Array): def tagged(self, tag: Tag) -> Array: raise ValueError("Cannot modify tags") - def without_tag(self, tag: Tag, verify_existence: bool = True) -> Array: + def without_tags(self, tag: Tag, verify_existence: bool = True) -> Array: raise ValueError("Cannot modify tags") # }}} diff --git a/setup.py b/setup.py index c5a2e6a..86db32e 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ setup(name="pytato", python_requires="~=3.8", install_requires=[ "loopy>=2020.2", - "pytools>=2020.4.2" + "pytools>=2020.4.4" ], author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei", -- GitLab