From 78ee5a1ae3d1f3afe1809023dc3377f4e98960ac Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 10 Jun 2022 16:24:23 -0500
Subject: [PATCH] Add tag_axes

Co-authored-by: Kaushik Kulkarni <kaushikcfd@gmail.com>
---
 arraycontext/__init__.py  |  4 ++--
 arraycontext/context.py   | 22 +++++++++++++++++++++-
 test/test_arraycontext.py | 26 ++++++++++++++++++++++++--
 3 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 3665cf6..bfbc14f 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -29,7 +29,7 @@ THE SOFTWARE.
 """
 
 import sys
-from .context import ArrayContext, Array, Scalar
+from .context import ArrayContext, Array, Scalar, tag_axes
 
 from .transform_metadata import (CommonSubexpressionTag,
         ElementwiseMapKernelTag)
@@ -78,7 +78,7 @@ from .loopy import make_loopy_program
 
 
 __all__ = (
-        "ArrayContext", "Scalar", "Array",
+        "ArrayContext", "Scalar", "Array", "tag_axes",
 
         "CommonSubexpressionTag",
         "ElementwiseMapKernelTag",
diff --git a/arraycontext/context.py b/arraycontext/context.py
index 72bf8c7..b6278e2 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -80,6 +80,7 @@ The interface of an array context
 .. autoclass:: Array
 .. autoclass:: Scalar
 .. autoclass:: ArrayContext
+.. autofunction:: tag_axes
 
 Internal typing helpers (do not import)
 ---------------------------------------
@@ -122,7 +123,7 @@ THE SOFTWARE.
 
 from abc import ABC, abstractmethod
 from typing import (
-        Any, Callable, Dict, Optional, Tuple, Union,
+        Any, Callable, Dict, Optional, Tuple, Union, Mapping,
         TYPE_CHECKING, TypeVar)
 
 import numpy as np
@@ -488,4 +489,23 @@ class ArrayContext(ABC):
 
 # }}}
 
+
+# {{{ tagging helpers
+
+def tag_axes(
+        actx: ArrayContext,
+        dim_to_tags: Mapping[int, ToTagSetConvertible],
+        ary: ArrayT) -> ArrayT:
+    """
+    Return a copy of *ary* with the axes in *dim_to_tags* tagged with their
+    corresponding tags. Equivalent to repeated application of
+    :meth:`ArrayContext.tag_axis`.
+    """
+    for iaxis, tags in dim_to_tags.items():
+        ary = actx.tag_axis(iaxis, tags, ary)
+
+    return ary
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 0848987..154af2f 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -35,8 +35,9 @@ from arraycontext import (
         FirstAxisIsElementsTag,
         PyOpenCLArrayContext,
         PytatoPyOpenCLArrayContext,
+        EagerJAXArrayContext,
         ArrayContainer,
-        to_numpy)
+        to_numpy, tag_axes)
 from arraycontext import (  # noqa: F401
         pytest_generate_tests_for_array_contexts,
         )
@@ -1442,7 +1443,7 @@ def test_actx_compile_on_pure_array_return(actx_factory):
 # }}}
 
 
-# {{{
+# {{{ test_taggable_cl_array_tags
 
 def test_taggable_cl_array_tags(actx_factory):
     actx = actx_factory()
@@ -1497,6 +1498,27 @@ def test_to_numpy_on_frozen_arrays(actx_factory):
     np.testing.assert_allclose(to_numpy(u, actx), 1)
 
 
+def test_tagging(actx_factory):
+    actx = actx_factory()
+
+    if isinstance(actx, EagerJAXArrayContext):
+        pytest.skip("Eager JAX has no tagging support")
+
+    from pytools.tag import Tag
+
+    class ExampleTag(Tag):
+        pass
+
+    ary = tag_axes(actx, {0: ExampleTag()},
+            actx.tag(
+                ExampleTag(),
+                actx.zeros((20, 20), dtype=np.float64)))
+
+    assert ary.tags_of_type(ExampleTag)
+    assert ary.axes[0].tags_of_type(ExampleTag)
+    assert not ary.axes[1].tags_of_type(ExampleTag)
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab