From 78ee5a1ae3d1f3afe1809023dc3377f4e98960ac Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Fri, 10 Jun 2022 16:24:23 -0500 Subject: [PATCH] Add tag_axes Co-authored-by: Kaushik Kulkarni <kaushikcfd@gmail.com> --- arraycontext/__init__.py | 4 ++-- arraycontext/context.py | 22 +++++++++++++++++++++- test/test_arraycontext.py | 26 ++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 3665cf6..bfbc14f 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -29,7 +29,7 @@ THE SOFTWARE. """ import sys -from .context import ArrayContext, Array, Scalar +from .context import ArrayContext, Array, Scalar, tag_axes from .transform_metadata import (CommonSubexpressionTag, ElementwiseMapKernelTag) @@ -78,7 +78,7 @@ from .loopy import make_loopy_program __all__ = ( - "ArrayContext", "Scalar", "Array", + "ArrayContext", "Scalar", "Array", "tag_axes", "CommonSubexpressionTag", "ElementwiseMapKernelTag", diff --git a/arraycontext/context.py b/arraycontext/context.py index 72bf8c7..b6278e2 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -80,6 +80,7 @@ The interface of an array context .. autoclass:: Array .. autoclass:: Scalar .. autoclass:: ArrayContext +.. autofunction:: tag_axes Internal typing helpers (do not import) --------------------------------------- @@ -122,7 +123,7 @@ THE SOFTWARE. from abc import ABC, abstractmethod from typing import ( - Any, Callable, Dict, Optional, Tuple, Union, + Any, Callable, Dict, Optional, Tuple, Union, Mapping, TYPE_CHECKING, TypeVar) import numpy as np @@ -488,4 +489,23 @@ class ArrayContext(ABC): # }}} + +# {{{ tagging helpers + +def tag_axes( + actx: ArrayContext, + dim_to_tags: Mapping[int, ToTagSetConvertible], + ary: ArrayT) -> ArrayT: + """ + Return a copy of *ary* with the axes in *dim_to_tags* tagged with their + corresponding tags. Equivalent to repeated application of + :meth:`ArrayContext.tag_axis`. + """ + for iaxis, tags in dim_to_tags.items(): + ary = actx.tag_axis(iaxis, tags, ary) + + return ary + +# }}} + # vim: foldmethod=marker diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 0848987..154af2f 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -35,8 +35,9 @@ from arraycontext import ( FirstAxisIsElementsTag, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, + EagerJAXArrayContext, ArrayContainer, - to_numpy) + to_numpy, tag_axes) from arraycontext import ( # noqa: F401 pytest_generate_tests_for_array_contexts, ) @@ -1442,7 +1443,7 @@ def test_actx_compile_on_pure_array_return(actx_factory): # }}} -# {{{ +# {{{ test_taggable_cl_array_tags def test_taggable_cl_array_tags(actx_factory): actx = actx_factory() @@ -1497,6 +1498,27 @@ def test_to_numpy_on_frozen_arrays(actx_factory): np.testing.assert_allclose(to_numpy(u, actx), 1) +def test_tagging(actx_factory): + actx = actx_factory() + + if isinstance(actx, EagerJAXArrayContext): + pytest.skip("Eager JAX has no tagging support") + + from pytools.tag import Tag + + class ExampleTag(Tag): + pass + + ary = tag_axes(actx, {0: ExampleTag()}, + actx.tag( + ExampleTag(), + actx.zeros((20, 20), dtype=np.float64))) + + assert ary.tags_of_type(ExampleTag) + assert ary.axes[0].tags_of_type(ExampleTag) + assert not ary.axes[1].tags_of_type(ExampleTag) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab