diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index b5e7f6394f365b347c2ae783a8de886fab3df682..f6263452d68b112964fdb78c25ebcb9100f35af5 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -265,7 +265,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): def _record_leaf_ary_in_dict(key: Tuple[Any, ...], ary: ArrayT): - key_str = "_actx" + _ary_container_key_stringifier(key) + key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary return ary @@ -340,7 +340,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): } def _to_frozen(key: Tuple[Any, ...], ary: ArrayT): - key_str = "_actx" + _ary_container_key_stringifier(key) + key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] return with_array_context(rec_keyed_map_array_container(_to_frozen, @@ -488,7 +488,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): def _record_leaf_ary_in_dict(key: Tuple[Any, ...], ary: Union[DeviceArray, pt.Array]): - key_str = "_actx" + _ary_container_key_stringifier(key) + key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary return ary @@ -524,7 +524,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): } def _to_frozen(key: Tuple[Any, ...], ary: ArrayT): - key_str = "_actx" + _ary_container_key_stringifier(key) + key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] return with_array_context(rec_keyed_map_array_container(_to_frozen,