diff --git a/pytato/array.py b/pytato/array.py index 6e3bf69757173d8545a22aff88d5ed822be4ff27..1391103fb33823cffd9f9f897f10e4d803788b01 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -91,7 +91,7 @@ import pytato.scalar_expr as scalar_expr from dataclasses import dataclass from pytools import is_single_valued -from typing import Optional, Dict, Any, Mapping, Iterator, Tuple, Union +from typing import Optional, Dict, Any, Mapping, Iterator, Tuple, Union, FrozenSet # {{{ dotted name @@ -240,7 +240,7 @@ class UniqueTag(Tag): """ -TagsType = Dict[DottedName, Tag] +TagsType = FrozenSet[Tag] # }}} @@ -356,8 +356,7 @@ class Array: .. attribute:: tags - A :class:`dict` mapping :class:`DottedName` instances - to instances of the :class:`Tag` interface. + A :class:`tuple` of :class:`Tag` instances. Motivation: `RDF `__ @@ -377,7 +376,7 @@ class Array: def __init__(self, namespace: Namespace, tags: Optional[TagsType] = None): if tags is None: - tags = {} + tags = frozenset() self.namespace = namespace self.tags = tags @@ -404,11 +403,17 @@ class Array: tags of this type are already present, an error is raised. """ - raise NotImplementedError - return self.copy() + return self.copy(tags=self.tags | frozenset([tag])) - def without_tag(self, dotted_name: DottedName) -> Array: - raise NotImplementedError + def without_tag(self, tag: Tag, verify_existence: bool = True) -> Array: + new_tags = tuple( + t for t in self.tags + if t != tag) + + if verify_existence and len(new_tags) == len(self.tags): + raise ValueError(f"tag '{tag}' was not present") + + return self.copy(tags=new_tags) # TODO: # - codegen interface @@ -536,7 +541,7 @@ class IndexLambda(Array): self, namespace: Namespace, expr: prim.Expression, shape: ShapeType, dtype: np.dtype, bindings: Optional[Dict[str, Array]] = None, - tags: Optional[Dict[DottedName, Tag]] = None): + tags: Optional[TagsType] = None): if bindings is None: bindings = {}