Skip to content
Snippets Groups Projects
Commit ffb9928f authored by Nicholas Christensen's avatar Nicholas Christensen
Browse files

use Taggable in pytools

parent b1d614b0
No related branches found
No related tags found
No related merge requests found
...@@ -152,7 +152,8 @@ import numpy as np ...@@ -152,7 +152,8 @@ import numpy as np
import pymbolic.primitives as prim import pymbolic.primitives as prim
from pymbolic import var from pymbolic import var
from pytools import is_single_valued, memoize_method, UniqueNameGenerator from pytools import is_single_valued, memoize_method, UniqueNameGenerator
from pytools.tag import Tag, UniqueTag, TagsType, tag_dataclass from pytools.tag import Tag, Taggable, UniqueTag, TagsType, TagOrTagsType,
tag_dataclass
import pytato.scalar_expr as scalar_expr import pytato.scalar_expr as scalar_expr
from pytato.scalar_expr import ScalarExpression, IntegralScalarExpression from pytato.scalar_expr import ScalarExpression, IntegralScalarExpression
...@@ -329,7 +330,7 @@ def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype: ...@@ -329,7 +330,7 @@ def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype:
return dtype return dtype
class Array: class Array(Taggable):
""" """
A base class (abstract interface + supplemental functionality) for lazily A base class (abstract interface + supplemental functionality) for lazily
evaluating array expressions. The interface seeks to maximize :mod:`numpy` evaluating array expressions. The interface seeks to maximize :mod:`numpy`
...@@ -410,11 +411,8 @@ class Array: ...@@ -410,11 +411,8 @@ class Array:
# hashable. Dicts of hashable keys and values are also permitted. # hashable. Dicts of hashable keys and values are also permitted.
_fields: ClassVar[Tuple[str, ...]] = ("shape", "dtype", "tags") _fields: ClassVar[Tuple[str, ...]] = ("shape", "dtype", "tags")
def __init__(self, tags: Optional[TagsType] = None): def __init__(self, tags: TagOrTagsType = frozenset()):
if tags is None: super(self, tags=tags)
tags = frozenset()
self.tags = tags
def copy(self, **kwargs: Any) -> Array: def copy(self, **kwargs: Any) -> Array:
raise NotImplementedError raise NotImplementedError
...@@ -530,22 +528,17 @@ class Array: ...@@ -530,22 +528,17 @@ class Array:
def T(self) -> Array: def T(self) -> Array:
return AxisPermutation(self, tuple(range(self.ndim)[::-1])) return AxisPermutation(self, tuple(range(self.ndim)[::-1]))
def tagged(self, tag: Tag) -> Array: def tagged(self, tags: TagOrTagsType) -> Array:
""" """
Returns a copy of *self* tagged with *tag*. Returns a copy of *self* tagged with *tags*.
If *tag* is a :class:`pytools.tag.UniqueTag` and other If *tags* is/has a :class:`pytools.tag.UniqueTag` and other
tags of this type are already present, an error tags of this type are already present, an error
is raised. is raised.
""" """
return self.copy(tags=self.tags | frozenset([tag])) return super.tagged(tags)
def without_tag(self, tag: Tag, verify_existence: bool = True) -> Array:
new_tags = tuple(t for t in self.tags if t != tag)
if verify_existence and len(new_tags) == len(self.tags):
raise ValueError(f"tag '{tag}' was not present")
return self.copy(tags=new_tags) def without_tags(self, tags: TagOrTagsType, verify_existence: bool = True) -> Array:
return super.without_tags(tags, verify_existence=verify_existence)
@memoize_method @memoize_method
def __hash__(self) -> int: def __hash__(self) -> int:
...@@ -1234,7 +1227,7 @@ class InputArgumentBase(Array): ...@@ -1234,7 +1227,7 @@ class InputArgumentBase(Array):
def tagged(self, tag: Tag) -> Array: def tagged(self, tag: Tag) -> Array:
raise ValueError("Cannot modify tags") raise ValueError("Cannot modify tags")
def without_tag(self, tag: Tag, verify_existence: bool = True) -> Array: def without_tags(self, tag: Tag, verify_existence: bool = True) -> Array:
raise ValueError("Cannot modify tags") raise ValueError("Cannot modify tags")
# }}} # }}}
......
...@@ -35,7 +35,7 @@ setup(name="pytato", ...@@ -35,7 +35,7 @@ setup(name="pytato",
python_requires="~=3.8", python_requires="~=3.8",
install_requires=[ install_requires=[
"loopy>=2020.2", "loopy>=2020.2",
"pytools>=2020.4.2" "pytools>=2020.4.4"
], ],
author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei", author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment