From 97df833e8439842f8e73ce7cdde87e45e221b732 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Tue, 22 Jun 2021 05:46:07 -0500
Subject: [PATCH] make leaf array descriptor its own class

---
 arraycontext/impl/pytato/compile.py | 27 +++++++++++++++++++++++----
 1 file changed, 23 insertions(+), 4 deletions(-)

diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 43488c4..d83ece6 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -38,7 +38,13 @@ import pyopencl.array as cla
 import pytato as pt
 
 
+# {{{ helper classes: AbstractInputDescriptor
+
 class AbstractInputDescriptor:
+    """
+    Used internally in :class:`LazilyCompilingFunctionCaller` to characterize
+    an input.
+    """
     def __eq__(self, other):
         raise NotImplementedError
 
@@ -51,10 +57,23 @@ class ScalarInputDescriptor(AbstractInputDescriptor):
     dtype: np.dtype
 
 
+@dataclass(frozen=True, eq=True)
+class LeafArrayDescriptor:
+    dtype: np.dtype
+    shape: Tuple[int, ...]
+
+
 @dataclass(frozen=True, eq=True)
 class ArrayContainerInputDescriptor(AbstractInputDescriptor):
-    id_to_ary_descr: "PMap[Tuple[Any, ...], Tuple[np.dtype, \
-                                                  Tuple[int, ...]]]"
+    """
+    .. attribute id_to_ary_descr::
+
+        A mapping from keys of leaf arrays of an array container to their
+        :class:`LeafArrayDescriptor`.
+    """
+    id_to_ary_descr: "PMap[Tuple[Any, ...], LeafArrayDescriptor]"
+
+# }}}
 
 
 def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
@@ -116,8 +135,8 @@ class LazilyCompilingFunctionCaller:
                 id_to_ary_descr = {}
 
                 def id_collector(keys, ary):
-                    id_to_ary_descr[keys] = (np.dtype(ary.dtype),
-                                             ary.shape)
+                    id_to_ary_descr[keys] = LeafArrayDescriptor(np.dtype(ary.dtype),
+                                                                ary.shape)
                     return ary
 
                 rec_keyed_map_array_container(id_collector, arg)
-- 
GitLab