From f830a544225e311f83d5ab993903a8953c173e85 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 20 Jun 2021 17:16:26 -0500 Subject: [PATCH] PytatoArrayCompilerOp: account for keys to be tuples --- arraycontext/impl/pytato.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index 153ab14..d690570 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -172,6 +172,23 @@ class ArrayContainerInputDescriptor(AbstractInputDescriptor): Tuple[int, ...]]]" +def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: + """ + Helper for :meth:`PytatoCompiledOperator.__call__`. Stringifies an + array-container's component's key. The aim is that no-two different keys + have the same stringification. + """ + def _rec_str(key: Any) -> str: + if isinstance(key, (str, int)): + return key + elif isinstance(key, tuple): + return "tup" + "_".join(_rec_str(k) for k in key) + "endtup" + else: + raise NotImplementedError + + return "_".join(_rec_str(key) for key in keys) + + @dataclass class PytatoCompiledOperator: actx: ArrayContext @@ -226,8 +243,8 @@ class PytatoCompiledOperator: return pt.make_placeholder((), np.dtype(arg), name) elif is_array_container(arg): def _rec_to_placeholder(keys, ary): - name = f"_actx_in_{pos}_" + "_".join(str(key) - for key in keys) + name = (f"_actx_in_{pos}_" + + _ary_container_key_stringifier(keys)) input_naming_map[(pos,) + keys] = name return pt.make_placeholder(ary.shape, ary.dtype, name) -- GitLab