diff --git a/pytato/array.py b/pytato/array.py index 279d2cad2c42797579222ccbb2e9fe78772aa28d..f39109bf16ce9888949cfc7346e442ece2dca72e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -225,9 +225,12 @@ ConvertibleToShape = Union[ Sequence[ShapeComponent]] -def _check_identifier(s: Optional[str]) -> bool: +def _check_identifier(s: Optional[str], optional: bool) -> bool: if s is None: - return True + if optional: + return True + else: + raise ValueError(f"'{s}' is not a valid identifier") if not s.isidentifier(): raise ValueError(f"'{s}' is not a valid identifier") @@ -1453,12 +1456,6 @@ class InputArgumentBase(Array): super().__init__(tags=tags) self.name = name - def __hash__(self) -> int: - return id(self) - - def __eq__(self, other: Any) -> bool: - return self is other - # }}} @@ -1495,6 +1492,12 @@ class DataWrapper(InputArgumentBase): Starting with the construction of the :class:`DataWrapper`, this array may not be updated in-place. + + .. note:: + + Since we cannot compare instances of :class:`DataInterface` being + wrapped, a :class:`DataWrapper` instances compare equal to themselves + (i.e. the very same instance). """ _fields = InputArgumentBase._fields + ("data", "shape") @@ -1510,6 +1513,12 @@ class DataWrapper(InputArgumentBase): self.data = data self._shape = shape + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: Any) -> bool: + return self is other + @property def shape(self) -> ShapeType: return self._shape @@ -1535,7 +1544,7 @@ class Placeholder(_SuppliedShapeAndDtypeMixin, InputArgumentBase): _mapper_method = "map_placeholder" def __init__(self, - name: Optional[str], + name: str, shape: ShapeType, dtype: np.dtype[Any], tags: TagsType = frozenset()): @@ -1559,6 +1568,11 @@ class SizeParam(InputArgumentBase): _mapper_method = "map_size_param" + def __init__(self, + name: str, + tags: TagsType = frozenset()): + super().__init__(name=name, tags=tags) + @property def shape(self) -> ShapeType: return () @@ -1821,10 +1835,10 @@ def make_dict_of_named_arrays(data: Dict[str, Array]) -> DictOfNamedArrays: # }}} -def make_placeholder(shape: ConvertibleToShape, - dtype: Any, - name: Optional[str] = None, - tags: TagsType = frozenset()) -> Placeholder: +def make_placeholder(name: str, + shape: ConvertibleToShape, + dtype: Any, + tags: TagsType = frozenset()) -> Placeholder: """Make a :class:`Placeholder` object. :param name: name of the placeholder array, generated automatically @@ -1834,7 +1848,7 @@ def make_placeholder(shape: ConvertibleToShape, (must be convertible to :class:`numpy.dtype`) :param tags: implementation tags """ - _check_identifier(name) + _check_identifier(name, optional=False) shape = normalize_shape(shape) dtype = np.dtype(dtype) @@ -1842,7 +1856,7 @@ def make_placeholder(shape: ConvertibleToShape, def make_size_param(name: str, - tags: TagsType = frozenset()) -> SizeParam: + tags: TagsType = frozenset()) -> SizeParam: """Make a :class:`SizeParam`. Size parameters may be used as variables in symbolic expressions for array @@ -1851,7 +1865,7 @@ def make_size_param(name: str, :param name: name :param tags: implementation tags """ - _check_identifier(name) + _check_identifier(name, optional=False) return SizeParam(name, tags=tags) @@ -1866,7 +1880,7 @@ def make_data_wrapper(data: DataInterface, :param shape: optional shape of the array, inferred from *data* if not given :param tags: implementation tags """ - _check_identifier(name) + _check_identifier(name, optional=True) if shape is None: shape = data.shape diff --git a/pytato/transform.py b/pytato/transform.py index 25e52ca29d79ed2246c8dea7e9d4ec9e347f1aba..05cd3eb3831462d9ae5f1073eed278de2ccbc1d8 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -122,6 +122,7 @@ class CopyMapper(Mapper): tags=expr.tags) def map_placeholder(self, expr: Placeholder) -> Array: + assert expr.name is not None return Placeholder(name=expr.name, shape=tuple(self.rec(s) if isinstance(s, Array) else s for s in expr.shape), @@ -166,6 +167,7 @@ class CopyMapper(Mapper): tags=expr.tags) def map_size_param(self, expr: SizeParam) -> Array: + assert expr.name is not None return SizeParam(name=expr.name, tags=expr.tags) def map_einsum(self, expr: Einsum) -> Array: @@ -536,18 +538,23 @@ class CachedWalkMapper(WalkMapper): """ def __init__(self) -> None: - self._visited_nodes: Set[ArrayOrNames] = set() + self._visited_ids: Set[int] = set() # type-ignore reason: CachedWalkMapper.rec's type does not match # WalkMapper.rec's type def rec(self, expr: ArrayOrNames) -> None: # type: ignore - if expr in self._visited_nodes: + # Why choose id(x) as the cache key? + # - Some downstream users (NamesValidityChecker) of this mapper rely on + # structurally equal objects being walked separately (e.g. to detect + # separate instances of Placeholder with the same name). + + if id(expr) in self._visited_ids: return # type-ignore reason: super().rec expects either 'Array' or # 'AbstractResultWithNamedArrays', passed 'ArrayOrNames' super().rec(expr) # type: ignore - self._visited_nodes.add(expr) + self._visited_ids.add(id(expr)) # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 99bb883d328c7368393e2e68a68419087dd37955..1d8c6c03d642073604f245ce949b5abff12cd6df 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -109,7 +109,7 @@ def test_stack_input_validation(): @pytest.mark.xfail # Unnamed placeholders should be used via pt.bind def test_make_placeholder_noname(): - x = pt.make_placeholder(shape=(10, 4), dtype=float) + x = pt.make_placeholder("x", shape=(10, 4), dtype=float) y = 2*x knl = pt.generate_loopy(y).kernel @@ -119,7 +119,7 @@ def test_make_placeholder_noname(): def test_zero_length_arrays(): - x = pt.make_placeholder(shape=(0, 4), dtype=float) + x = pt.make_placeholder("x", shape=(0, 4), dtype=float) y = 2*x assert y.shape == (0, 4) @@ -149,7 +149,7 @@ def test_concatenate_input_validation(): def test_reshape_input_validation(): - x = pt.make_placeholder(shape=(3, 3, 4), dtype=np.float64) + x = pt.make_placeholder("x", shape=(3, 3, 4), dtype=np.float64) assert pt.reshape(x, (-1,)).shape == (36,) assert pt.reshape(x, (-1, 6)).shape == (6, 6) @@ -202,7 +202,7 @@ def test_same_placeholder_name_raises(): def test_einsum_error_handling(): with pytest.raises(ValueError): # operands not enough - pt.einsum("ij,j->j", pt.make_placeholder((2, 2), float)) + pt.einsum("ij,j->j", pt.make_placeholder("x", (2, 2), float)) with pytest.raises(ValueError): # double index use in the out spec.