From b8dd7aef42f7bbe261fcf400357addf051b024e7 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com> Date: Fri, 16 Jul 2021 17:52:44 -0500 Subject: [PATCH] Make pt.(Placeholder|SizeParam) equality comparable (#123) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * make pt.(Placeholder|SizeParam) equality comparable The only input argument for which it makes sense to have identity comparison is data wrappers so that the DAG comparisons are efficient. This also led to some downstream re-writing of mappers' cache keys that relied on object equality. * Placeholder without a name is moot No sensible program would actually need a non-named placeholder. * fixup! make pt.(Placeholder|SizeParam) equality comparable * adds a note about DataWrapper.__eq__'s implementation * fixup! make pt.(Placeholder|SizeParam) equality comparable * Clarify DataWrapper equality definition * Clarify comment explaining use of id() for CachedWalkMapper Co-authored-by: Andreas Klöckner --- pytato/array.py | 48 +++++++++++++++++++++++++++++---------------- pytato/transform.py | 13 +++++++++--- test/test_pytato.py | 8 ++++---- 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 279d2ca..f39109b 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -225,9 +225,12 @@ ConvertibleToShape = Union[ Sequence[ShapeComponent]] -def _check_identifier(s: Optional[str]) -> bool: +def _check_identifier(s: Optional[str], optional: bool) -> bool: if s is None: - return True + if optional: + return True + else: + raise ValueError(f"'{s}' is not a valid identifier") if not s.isidentifier(): raise ValueError(f"'{s}' is not a valid identifier") @@ -1453,12 +1456,6 @@ class InputArgumentBase(Array): super().__init__(tags=tags) self.name = name - def __hash__(self) -> int: - return id(self) - - def __eq__(self, other: Any) -> bool: - return self is other - # }}} @@ -1495,6 +1492,12 
@@ class DataWrapper(InputArgumentBase): Starting with the construction of the :class:`DataWrapper`, this array may not be updated in-place. + + .. note:: + + Since we cannot compare instances of :class:`DataInterface` being + wrapped, a :class:`DataWrapper` instances compare equal to themselves + (i.e. the very same instance). """ _fields = InputArgumentBase._fields + ("data", "shape") @@ -1510,6 +1513,12 @@ class DataWrapper(InputArgumentBase): self.data = data self._shape = shape + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: Any) -> bool: + return self is other + @property def shape(self) -> ShapeType: return self._shape @@ -1535,7 +1544,7 @@ class Placeholder(_SuppliedShapeAndDtypeMixin, InputArgumentBase): _mapper_method = "map_placeholder" def __init__(self, - name: Optional[str], + name: str, shape: ShapeType, dtype: np.dtype[Any], tags: TagsType = frozenset()): @@ -1559,6 +1568,11 @@ class SizeParam(InputArgumentBase): _mapper_method = "map_size_param" + def __init__(self, + name: str, + tags: TagsType = frozenset()): + super().__init__(name=name, tags=tags) + @property def shape(self) -> ShapeType: return () @@ -1821,10 +1835,10 @@ def make_dict_of_named_arrays(data: Dict[str, Array]) -> DictOfNamedArrays: # }}} -def make_placeholder(shape: ConvertibleToShape, - dtype: Any, - name: Optional[str] = None, - tags: TagsType = frozenset()) -> Placeholder: +def make_placeholder(name: str, + shape: ConvertibleToShape, + dtype: Any, + tags: TagsType = frozenset()) -> Placeholder: """Make a :class:`Placeholder` object. 
:param name: name of the placeholder array, generated automatically @@ -1834,7 +1848,7 @@ def make_placeholder(shape: ConvertibleToShape, (must be convertible to :class:`numpy.dtype`) :param tags: implementation tags """ - _check_identifier(name) + _check_identifier(name, optional=False) shape = normalize_shape(shape) dtype = np.dtype(dtype) @@ -1842,7 +1856,7 @@ def make_placeholder(shape: ConvertibleToShape, def make_size_param(name: str, - tags: TagsType = frozenset()) -> SizeParam: + tags: TagsType = frozenset()) -> SizeParam: """Make a :class:`SizeParam`. Size parameters may be used as variables in symbolic expressions for array @@ -1851,7 +1865,7 @@ def make_size_param(name: str, :param name: name :param tags: implementation tags """ - _check_identifier(name) + _check_identifier(name, optional=False) return SizeParam(name, tags=tags) @@ -1866,7 +1880,7 @@ def make_data_wrapper(data: DataInterface, :param shape: optional shape of the array, inferred from *data* if not given :param tags: implementation tags """ - _check_identifier(name) + _check_identifier(name, optional=True) if shape is None: shape = data.shape diff --git a/pytato/transform.py b/pytato/transform.py index 25e52ca..05cd3eb 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -122,6 +122,7 @@ class CopyMapper(Mapper): tags=expr.tags) def map_placeholder(self, expr: Placeholder) -> Array: + assert expr.name is not None return Placeholder(name=expr.name, shape=tuple(self.rec(s) if isinstance(s, Array) else s for s in expr.shape), @@ -166,6 +167,7 @@ class CopyMapper(Mapper): tags=expr.tags) def map_size_param(self, expr: SizeParam) -> Array: + assert expr.name is not None return SizeParam(name=expr.name, tags=expr.tags) def map_einsum(self, expr: Einsum) -> Array: @@ -536,18 +538,23 @@ class CachedWalkMapper(WalkMapper): """ def __init__(self) -> None: - self._visited_nodes: Set[ArrayOrNames] = set() + self._visited_ids: Set[int] = set() # type-ignore reason: CachedWalkMapper.rec's type 
does not match # WalkMapper.rec's type def rec(self, expr: ArrayOrNames) -> None: # type: ignore - if expr in self._visited_nodes: + # Why choose id(x) as the cache key? + # - Some downstream users (NamesValidityChecker) of this mapper rely on + # structurally equal objects being walked separately (e.g. to detect + # separate instances of Placeholder with the same name). + + if id(expr) in self._visited_ids: return # type-ignore reason: super().rec expects either 'Array' or # 'AbstractResultWithNamedArrays', passed 'ArrayOrNames' super().rec(expr) # type: ignore - self._visited_nodes.add(expr) + self._visited_ids.add(id(expr)) # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 99bb883..1d8c6c0 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -109,7 +109,7 @@ def test_stack_input_validation(): @pytest.mark.xfail # Unnamed placeholders should be used via pt.bind def test_make_placeholder_noname(): - x = pt.make_placeholder(shape=(10, 4), dtype=float) + x = pt.make_placeholder("x", shape=(10, 4), dtype=float) y = 2*x knl = pt.generate_loopy(y).kernel @@ -119,7 +119,7 @@ def test_make_placeholder_noname(): def test_zero_length_arrays(): - x = pt.make_placeholder(shape=(0, 4), dtype=float) + x = pt.make_placeholder("x", shape=(0, 4), dtype=float) y = 2*x assert y.shape == (0, 4) @@ -149,7 +149,7 @@ def test_concatenate_input_validation(): def test_reshape_input_validation(): - x = pt.make_placeholder(shape=(3, 3, 4), dtype=np.float64) + x = pt.make_placeholder("x", shape=(3, 3, 4), dtype=np.float64) assert pt.reshape(x, (-1,)).shape == (36,) assert pt.reshape(x, (-1, 6)).shape == (6, 6) @@ -202,7 +202,7 @@ def test_same_placeholder_name_raises(): def test_einsum_error_handling(): with pytest.raises(ValueError): # operands not enough - pt.einsum("ij,j->j", pt.make_placeholder((2, 2), float)) + pt.einsum("ij,j->j", pt.make_placeholder("x", (2, 2), float)) with pytest.raises(ValueError): # double index use in the out spec. -- GitLab