Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (233)
Showing changes with 2461 additions and 795 deletions
version: 2
updates:
# Set update schedule for GitHub Actions
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
# vim: sw=4
@@ -7,14 +7,15 @@ on:
 jobs:
     autopush:
         name: Automatic push to gitlab.tiker.net
+        if: startsWith(github.repository, 'inducer/')
         runs-on: ubuntu-latest
         steps:
-        - uses: actions/checkout@v2
+        - uses: actions/checkout@v4
        - run: |
-            mkdir ~/.ssh && echo -e "Host gitlab.tiker.net\n\tStrictHostKeyChecking no\n" >> ~/.ssh/config
-            eval $(ssh-agent) && echo "$GITLAB_AUTOPUSH_KEY" | ssh-add -
-            git fetch --unshallow
-            git push "git@gitlab.tiker.net:inducer/$(basename $GITHUB_REPOSITORY).git" main
+            curl -L -O https://tiker.net/ci-support-v0
+            . ./ci-support-v0
+            mirror_github_to_gitlab
          env:
            GITLAB_AUTOPUSH_KEY: ${{ secrets.GITLAB_AUTOPUSH_KEY }}
...
@@ -7,27 +7,35 @@ on:
     schedule:
         - cron: '17 3 * * 0'

+concurrency:
+    group: ${{ github.head_ref || github.ref_name }}
+    cancel-in-progress: true

 jobs:
-    flake8:
-        name: Flake8
+    typos:
+        name: Typos
         runs-on: ubuntu-latest
         steps:
-        -   uses: actions/checkout@v2
+        -   uses: actions/checkout@v4
+        -   uses: crate-ci/typos@master

+    ruff:
+        name: Ruff
+        runs-on: ubuntu-latest
+        steps:
+        -   uses: actions/checkout@v4
         -
-            uses: actions/setup-python@v1
-            with:
-                # matches compat target in setup.py
-                python-version: '3.6'
+            uses: actions/setup-python@v5
         -   name: "Main Script"
             run: |
-                curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh
-                . ./prepare-and-run-flake8.sh "$(basename $GITHUB_REPOSITORY)" test examples
+                pip install ruff
+                ruff check

     pylint:
         name: Pylint
         runs-on: ubuntu-latest
         steps:
-        -   uses: actions/checkout@v2
+        -   uses: actions/checkout@v4
         -   name: "Main Script"
             run: |
                 USE_CONDA_BUILD=1
@@ -38,24 +46,21 @@ jobs:
         name: Mypy
         runs-on: ubuntu-latest
         steps:
-        -   uses: actions/checkout@v2
-        -
-            uses: actions/setup-python@v1
-            with:
-                python-version: '3.x'
+        -   uses: actions/checkout@v4
         -   name: "Main Script"
             run: |
                 curl -L -O https://tiker.net/ci-support-v0
                 . ./ci-support-v0
                 build_py_project_in_conda_env
-                python -m pip install mypy
+                python -m pip install mypy pytest
                 ./run-mypy.sh

     pytest3_pocl:
         name: Pytest Conda Py3 POCL
         runs-on: ubuntu-latest
         steps:
-        -   uses: actions/checkout@v2
+        -   uses: actions/checkout@v4
         -   name: "Main Script"
             run: |
                 curl -L -O https://tiker.net/ci-support-v0
@@ -67,7 +72,7 @@ jobs:
         name: Pytest Conda Py3 Intel
         runs-on: ubuntu-latest
         steps:
-        -   uses: actions/checkout@v2
+        -   uses: actions/checkout@v4
         -   name: "Main Script"
             run: |
                 curl -L -O https://raw.githubusercontent.com/illinois-scicomp/machine-shop-maintenance/main/install-intel-icd.sh
@@ -88,7 +93,7 @@ jobs:
         name: Examples Conda Py3
         runs-on: ubuntu-latest
         steps:
-        -   uses: actions/checkout@v2
+        -   uses: actions/checkout@v4
         -   name: "Main Script"
             run: |
                 export MPLBACKEND=Agg
@@ -100,9 +105,9 @@ jobs:
         name: Documentation
         runs-on: ubuntu-latest
         steps:
-        -   uses: actions/checkout@v2
+        -   uses: actions/checkout@v4
         -
-            uses: actions/setup-python@v1
+            uses: actions/setup-python@v5
             with:
                 python-version: '3.x'
         -   name: "Main Script"
@@ -118,47 +123,20 @@ jobs:
     downstream_tests:
         strategy:
             matrix:
-                #downstream_project: [meshmode, grudge, pytential, mirgecom]
-                downstream_project: [meshmode, grudge, mirgecom]
+                downstream_project: [meshmode, grudge, mirgecom, mirgecom_examples]
             fail-fast: false
         name: Tests for downstream project ${{ matrix.downstream_project }}
         runs-on: ubuntu-latest
         steps:
-        -   uses: actions/checkout@v2
+        -   uses: actions/checkout@v4
         -   name: "Main Script"
             env:
                 DOWNSTREAM_PROJECT: ${{ matrix.downstream_project }}
             run: |
                 curl -L -O https://tiker.net/ci-support-v0
                 . ./ci-support-v0
-                if test "$DOWNSTREAM_PROJECT" = "mirgecom"; then
-                    git clone "https://github.com/illinois-ceesd/$DOWNSTREAM_PROJECT.git"
-                else
-                    git clone "https://github.com/inducer/$DOWNSTREAM_PROJECT.git"
-                fi
-                cd "$DOWNSTREAM_PROJECT"
-                echo "*** $DOWNSTREAM_PROJECT version: $(git rev-parse --short HEAD)"
-                # Use this version of arraycontext instead of what downstream would install
-                edit_requirements_txt_for_downstream_in_subdir
-                # Avoid slow or complicated tests in downstream projects
-                export PYTEST_ADDOPTS="-k 'not (slowtest or octave or mpi)'"
-                if test "$DOWNSTREAM_PROJECT" = "mirgecom"; then
-                    # can't turn off MPI in mirgecom
-                    export CONDA_ENVIRONMENT=conda-env.yml
-                    export CISUPPORT_PARALLEL_PYTEST=no
-                    echo "- mpi4py" >> "$CONDA_ENVIRONMENT"
-                else
-                    sed -i "/mpi4py/ d" requirements.txt
-                fi
-                build_py_project_in_conda_env
-                test_py_project
+                test_downstream "$DOWNSTREAM_PROJECT"

                 if [[ "$DOWNSTREAM_PROJECT" = "meshmode" ]]; then
                     python ../examples/simple-dg.py --lazy
...
 Python 3 POCL:
   script: |
-    export PY_EXE=python3
-    export PYOPENCL_TEST=portable:pthread
-    # cython is here because pytential (for now, for TS) depends on it
-    export EXTRA_INSTALL="pybind11 cython numpy mako mpi4py oct2py"
+    export PYOPENCL_TEST=portable:cpu
+    export EXTRA_INSTALL="jax[cpu]"
+    export JAX_PLATFORMS=cpu
     curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh
     . ./build-and-test-py-project.sh
   tags:
@@ -18,12 +17,30 @@ Python 3 POCL:

 Python 3 Nvidia Titan V:
   script: |
-    export PY_EXE=python3
+    curl -L -O https://tiker.net/ci-support-v0
+    . ./ci-support-v0
     export PYOPENCL_TEST=nvi:titan
-    export EXTRA_INSTALL="pybind11 cython numpy mako oct2py"
-    # cython is here because pytential (for now, for TS) depends on it
-    curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh
-    . ./build-and-test-py-project.sh
+    build_py_project_in_venv
+    pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
+    test_py_project
+  tags:
+  - python3
+  - nvidia-titan-v
+  except:
+  - tags
+  artifacts:
+    reports:
+      junit: test/pytest.xml

+Python 3 POCL Nvidia Titan V:
+  script: |
+    curl -L -O https://tiker.net/ci-support-v0
+    . ./ci-support-v0
+    export PYOPENCL_TEST=port:titan
+    build_py_project_in_venv
+    test_py_project
   tags:
   - python3
   - nvidia-titan-v
@@ -36,9 +53,7 @@ Python 3 Nvidia Titan V:

 Python 3 POCL Examples:
   script:
   - test -n "$SKIP_EXAMPLES" && exit
-  - export PY_EXE=python3
-  - export PYOPENCL_TEST=portable:pthread
-  - export EXTRA_INSTALL="pybind11 numpy mako"
+  - export PYOPENCL_TEST=portable:cpu
   - curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-py-project-and-run-examples.sh
   - ". ./build-py-project-and-run-examples.sh"
   tags:
@@ -50,6 +65,11 @@ Python 3 POCL Examples:

 Python 3 Conda:
   script: |
+    export PYOPENCL_TEST=portable:cpu
+    # Avoid crashes like https://gitlab.tiker.net/inducer/arraycontext/-/jobs/536021
+    sed -i 's/jax/jax !=0.4.6/' .test-conda-env-py3.yml
     curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh
     . ./build-and-test-py-project-within-miniconda.sh
   tags:
@@ -61,26 +81,24 @@ Python 3 Conda:

 Documentation:
   script: |
-    EXTRA_INSTALL="pybind11 cython numpy"
     curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-docs.sh
     CI_SUPPORT_SPHINX_VERSION_SPECIFIER=">=4.0"
     . ./build-docs.sh
   tags:
   - python3

-Flake8:
+Ruff:
   script:
-  - curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh
-  - . ./prepare-and-run-flake8.sh "$CI_PROJECT_NAME" test examples
+  - pipx install ruff
+  - ruff check
   tags:
-  - python3
+  - docker-runner
   except:
   - tags

 Pylint:
   script: |
-    export PY_EXE=python3
-    EXTRA_INSTALL="pybind11 numpy mako"
+    EXTRA_INSTALL="jax[cpu]"
     curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh
     . ./prepare-and-run-pylint.sh "$CI_PROJECT_NAME" examples/*.py test/test_*.py
   tags:
@@ -90,12 +108,30 @@ Pylint:

 Mypy:
   script: |
+    EXTRA_INSTALL="mypy pytest"
     curl -L -O https://tiker.net/ci-support-v0
     . ./ci-support-v0
     build_py_project_in_venv
-    python -m pip install mypy
     ./run-mypy.sh
   tags:
   - python3
   except:
   - tags

+Downstream:
+  parallel:
+    matrix:
+    - DOWNSTREAM_PROJECT: [meshmode, grudge, mirgecom, mirgecom_examples]
+  tags:
+  - large-node
+  - "docker-runner"
+  script: |
+    curl -L -O https://tiker.net/ci-support-v0
+    . ./ci-support-v0
+    test_downstream "$DOWNSTREAM_PROJECT"
+    if [[ "$DOWNSTREAM_PROJECT" = "meshmode" ]]; then
+        python ../examples/simple-dg.py --lazy
+    fi
@@ -8,8 +8,10 @@ dependencies:
 - git
 - libhwloc=2
 - numpy
-- pocl
+# pocl 3.1 required for full SVM functionality
+- pocl>=3.1
 - mako
 - pyopencl
 - islpy
 - pip
+- jax
include doc/*.rst
include doc/conf.py
include doc/make.bat
include doc/Makefile
include examples/*.py
@@ -7,7 +7,7 @@ arraycontext: Choose your favorite ``numpy``-workalike
 .. image:: https://github.com/inducer/arraycontext/workflows/CI/badge.svg
     :alt: Github Build Status
     :target: https://github.com/inducer/arraycontext/actions?query=branch%3Amain+workflow%3ACI
-.. image:: https://badge.fury.io/py/arraycontext.png
+.. image:: https://badge.fury.io/py/arraycontext.svg
     :alt: Python Package Index Release Page
     :target: https://pypi.org/project/arraycontext/
@@ -17,7 +17,9 @@ implementations for:
 - numpy
 - `PyOpenCL <https://documen.tician.de/pyopencl/array.html>`__
+- `JAX <https://jax.readthedocs.io/en/latest/>`__
 - `Pytato <https://documen.tician.de/pytato>`__ (for lazy/deferred evaluation)
+  with backends for ``pyopencl`` and ``jax``.
 - Debugging
 - Profiling
...
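For orientation, a minimal usage sketch of one of these implementations (not part of this diff; it assumes a working PyOpenCL installation, and the queue setup and sample data are illustrative only):

    import numpy as np
    import pyopencl as cl

    from arraycontext import PyOpenCLArrayContext

    ctx = cl.create_some_context()
    queue = cl.CommandQueue(ctx)
    # An allocator/memory pool can be passed for better performance.
    actx = PyOpenCLArrayContext(queue)

    # Move a numpy array to the device, compute on it, and bring it back.
    x = actx.from_numpy(np.linspace(0, 1, 10))
    y = actx.np.sin(x)
    print(actx.to_numpy(y))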
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
An array context is an abstraction that helps you dispatch between multiple An array context is an abstraction that helps you dispatch between multiple
implementations of :mod:`numpy`-like :math:`n`-dimensional arrays. implementations of :mod:`numpy`-like :math:`n`-dimensional arrays.
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
...@@ -28,84 +29,140 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -28,84 +29,140 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
import sys
from .context import ArrayContext, DeviceArray, DeviceScalar
from .transform_metadata import (CommonSubexpressionTag,
ElementwiseMapKernelTag)
# deprecated, remove in 2022.
from .metadata import _FirstAxisIsElementsTag
from .container import ( from .container import (
ArrayContainer, NotAnArrayContainerError, ArithArrayContainer,
is_array_container, is_array_container_type, ArrayContainer,
get_container_context, get_container_context_recursively, ArrayContainerT,
serialize_container, deserialize_container, NotAnArrayContainerError,
register_multivector_as_array_container) SerializationKey,
from .container.arithmetic import with_container_arithmetic SerializedContainer,
deserialize_container,
get_container_context_opt,
get_container_context_recursively,
get_container_context_recursively_opt,
is_array_container,
is_array_container_type,
register_multivector_as_array_container,
serialize_container,
)
from .container.arithmetic import (
with_container_arithmetic,
)
from .container.dataclass import dataclass_array_container from .container.dataclass import dataclass_array_container
from .container.traversal import ( from .container.traversal import (
map_array_container, flat_size_and_dtype,
multimap_array_container, flatten,
rec_map_array_container, freeze,
rec_multimap_array_container, from_numpy,
mapped_over_array_containers, map_array_container,
multimapped_over_array_containers, map_reduce_array_container,
map_reduce_array_container, mapped_over_array_containers,
multimap_reduce_array_container, multimap_array_container,
rec_map_reduce_array_container, multimap_reduce_array_container,
rec_multimap_reduce_array_container, multimapped_over_array_containers,
thaw, freeze, outer,
flatten, unflatten, rec_map_array_container,
from_numpy, to_numpy, rec_map_reduce_array_container,
outer) rec_multimap_array_container,
rec_multimap_reduce_array_container,
stringify_array_container_tree,
thaw,
to_numpy,
unflatten,
with_array_context,
)
from .context import (
Array,
ArrayContext,
ArrayOrArithContainer,
ArrayOrArithContainerOrScalar,
ArrayOrArithContainerOrScalarT,
ArrayOrArithContainerT,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ArrayOrContainerT,
ArrayT,
Scalar,
ScalarLike,
tag_axes,
)
from .impl.jax import EagerJAXArrayContext
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoPyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .pytest import (
PytestPyOpenCLArrayContextFactory,
pytest_generate_tests_for_array_contexts,
pytest_generate_tests_for_pyopencl_array_context)
from .loopy import make_loopy_program from .loopy import make_loopy_program
from .pytest import (
PytestArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
pytest_generate_tests_for_array_contexts,
)
from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
__all__ = ( __all__ = (
"ArrayContext", "DeviceScalar", "DeviceArray", "ArithArrayContainer",
"Array",
"CommonSubexpressionTag", "ArrayContainer",
"ElementwiseMapKernelTag", "ArrayContainerT",
"ArrayContext",
"ArrayContainer", "NotAnArrayContainerError", "ArrayOrArithContainer",
"is_array_container", "is_array_container_type", "ArrayOrArithContainerOrScalar",
"get_container_context", "get_container_context_recursively", "ArrayOrArithContainerOrScalarT",
"serialize_container", "deserialize_container", "ArrayOrArithContainerT",
"register_multivector_as_array_container", "ArrayOrContainer",
"with_container_arithmetic", "ArrayOrContainerOrScalar",
"dataclass_array_container", "ArrayOrContainerOrScalarT",
"ArrayOrContainerT",
"map_array_container", "multimap_array_container", "ArrayT",
"rec_map_array_container", "rec_multimap_array_container", "CommonSubexpressionTag",
"mapped_over_array_containers", "EagerJAXArrayContext",
"multimapped_over_array_containers", "ElementwiseMapKernelTag",
"map_reduce_array_container", "multimap_reduce_array_container", "NotAnArrayContainerError",
"rec_map_reduce_array_container", "rec_multimap_reduce_array_container", "NumpyArrayContext",
"thaw", "freeze", "PyOpenCLArrayContext",
"flatten", "unflatten", "PytatoJAXArrayContext",
"from_numpy", "to_numpy", "PytatoPyOpenCLArrayContext",
"outer", "PytestArrayContextFactory",
"PytestPyOpenCLArrayContextFactory",
"PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", "Scalar",
"ScalarLike",
"make_loopy_program", "SerializationKey",
"SerializedContainer",
"PytestPyOpenCLArrayContextFactory", "dataclass_array_container",
"pytest_generate_tests_for_array_contexts", "deserialize_container",
"pytest_generate_tests_for_pyopencl_array_context" "flat_size_and_dtype",
) "flatten",
"freeze",
"from_numpy",
"get_container_context_opt",
"get_container_context_recursively",
"get_container_context_recursively_opt",
"is_array_container",
"is_array_container_type",
"make_loopy_program",
"map_array_container",
"map_reduce_array_container",
"mapped_over_array_containers",
"multimap_array_container",
"multimap_reduce_array_container",
"multimapped_over_array_containers",
"outer",
"pytest_generate_tests_for_array_contexts",
"rec_map_array_container",
"rec_map_reduce_array_container",
"rec_multimap_array_container",
"rec_multimap_reduce_array_container",
"register_multivector_as_array_container",
"serialize_container",
"stringify_array_container_tree",
"tag_axes",
"thaw",
"to_numpy",
"unflatten",
"with_array_context",
"with_container_arithmetic",
)
 # {{{ deprecation handling
@@ -122,29 +179,24 @@ def _deprecated_acf():

 _depr_name_to_replacement_and_obj = {
-        "FirstAxisIsElementsTag":
-            ("meshmode.transform_metadata.FirstAxisIsElementsTag",
-                _FirstAxisIsElementsTag),
-        "_acf":
-            ("<no replacement yet>", _deprecated_acf),
+        "get_container_context": (
+            "get_container_context_opt",
+            get_container_context_opt, 2022),
         }

-if sys.version_info >= (3, 7):
-    def __getattr__(name):
-        replacement_and_obj = _depr_name_to_replacement_and_obj.get(name, None)
-        if replacement_and_obj is not None:
-            replacement, obj = replacement_and_obj
-            from warnings import warn
-            warn(f"'arraycontext.{name}' is deprecated. "
-                    f"Use '{replacement}' instead. "
-                    f"'arraycontext.{name}' will continue to work until 2022.",
-                    DeprecationWarning, stacklevel=2)
-            return obj
-        else:
-            raise AttributeError(name)
-else:
-    FirstAxisIsElementsTag = _FirstAxisIsElementsTag
-    _acf = _deprecated_acf
+def __getattr__(name):
+    replacement_and_obj = _depr_name_to_replacement_and_obj.get(name)
+    if replacement_and_obj is not None:
+        replacement, obj, year = replacement_and_obj
+        from warnings import warn
+        warn(f"'arraycontext.{name}' is deprecated. "
+                f"Use '{replacement}' instead. "
+                f"'arraycontext.{name}' will continue to work until {year}.",
+                DeprecationWarning, stacklevel=2)
+        return obj
+    else:
+        raise AttributeError(name)

 # }}}
...
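The module-level ``__getattr__`` above is what turns access to a removed name into a ``DeprecationWarning`` (PEP 562). A standalone sketch of the same pattern, with made-up names for illustration:

    # deprecation_shim.py -- standalone sketch of the pattern used above
    from warnings import warn


    def _new_function():
        return 42


    # maps old name -> (replacement name, object, year the alias goes away)
    _depr_name_to_replacement_and_obj = {
        "old_function": ("_new_function", _new_function, 2022),
    }


    def __getattr__(name):
        # Called by Python only when normal attribute lookup on the module fails.
        replacement_and_obj = _depr_name_to_replacement_and_obj.get(name)
        if replacement_and_obj is not None:
            replacement, obj, year = replacement_and_obj
            warn(f"'{name}' is deprecated. Use '{replacement}' instead. "
                 f"'{name}' will continue to work until {year}.",
                 DeprecationWarning, stacklevel=2)
            return obj
        raise AttributeError(name)

Importing ``old_function`` from such a module emits the warning while still returning the replacement object.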
...@@ -3,42 +3,58 @@ ...@@ -3,42 +3,58 @@
""" """
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
.. class:: ArrayT
:canonical: arraycontext.container.ArrayT
:class:`~typing.TypeVar` for arrays.
.. class:: ContainerT
:canonical: arraycontext.container.ContainerT
:class:`~typing.TypeVar` for array container-like objects.
.. class:: ArrayOrContainerT
:canonical: arraycontext.container.ArrayOrContainerT
:class:`~typing.TypeVar` for arrays or array container-like objects.
.. autoclass:: ArrayContainer .. autoclass:: ArrayContainer
.. autoclass:: ArithArrayContainer
.. class:: ArrayContainerT
A type variable with a lower bound of :class:`ArrayContainer`.
.. autoexception:: NotAnArrayContainerError .. autoexception:: NotAnArrayContainerError
Serialization/deserialization Serialization/deserialization
----------------------------- -----------------------------
.. autoclass:: SerializationKey
.. autoclass:: SerializedContainer
.. autofunction:: is_array_container_type .. autofunction:: is_array_container_type
.. autofunction:: serialize_container .. autofunction:: serialize_container
.. autofunction:: deserialize_container .. autofunction:: deserialize_container
Context retrieval Context retrieval
----------------- -----------------
.. autofunction:: get_container_context .. autofunction:: get_container_context_opt
.. autofunction:: get_container_context_recursively .. autofunction:: get_container_context_recursively
.. autofunction:: get_container_context_recursively_opt
:class:`~pymbolic.geometric_algebra.MultiVector` support :class:`~pymbolic.geometric_algebra.MultiVector` support
--------------------------------------------------------- ---------------------------------------------------------
.. autofunction:: register_multivector_as_array_container .. autofunction:: register_multivector_as_array_container
.. currentmodule:: arraycontext.container
Canonical locations for type annotations
----------------------------------------
.. class:: ArrayContainerT
:canonical: arraycontext.ArrayContainerT
.. class:: ArrayOrContainerT
:canonical: arraycontext.ArrayOrContainerT
.. class:: SerializationKey
:canonical: arraycontext.SerializationKey
.. class:: SerializedContainer
:canonical: arraycontext.SerializedContainer
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
...@@ -64,24 +80,30 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -64,24 +80,30 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from collections.abc import Hashable, Sequence
from functools import singledispatch from functools import singledispatch
from arraycontext.context import ArrayContext from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar
from typing import Any, Iterable, Tuple, TypeVar, Optional, Union, TYPE_CHECKING
# For use in singledispatch type annotations, because sphinx can't figure out
# what 'np' is.
import numpy
import numpy as np import numpy as np
from typing_extensions import Self
from arraycontext.context import ArrayContext, ArrayOrScalar
ArrayT = TypeVar("ArrayT")
ContainerT = TypeVar("ContainerT")
ArrayOrContainerT = Union[ArrayT, ContainerT]
if TYPE_CHECKING: if TYPE_CHECKING:
from pymbolic.geometric_algebra import MultiVector from pymbolic.geometric_algebra import MultiVector
from arraycontext import ArrayOrContainer
# {{{ ArrayContainer # {{{ ArrayContainer
-class ArrayContainer:
-    r"""
-    A generic container for the array type supported by the
+class ArrayContainer(Protocol):
+    """
+    A protocol for generic containers of the array type supported by the
     :class:`ArrayContext`.

     The functionality required for the container to operated is supplied via
@@ -92,7 +114,7 @@ class ArrayContainer:
       of the array.
     * :func:`deserialize_container` for deserialization, which constructs a
       container from a set of components.
-    * :func:`get_container_context` retrieves the :class:`ArrayContext` from
+    * :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
       a container, if it has one.

     This allows enumeration of the component arrays in a container and the
@@ -112,31 +134,70 @@ class ArrayContainer:
     .. note::

-        This class is used in type annotation. Inheriting from it confers no
-        special meaning or behavior.
+        This class is used in type annotation and as a marker of array container
+        attributes for :func:`~arraycontext.dataclass_array_container`.
+        As a protocol, it is not intended as a superclass.
     """

+    # Array containers do not need to have any particular features, so this
+    # protocol is deliberately empty.

+    # This *is* used as a type annotation in dataclasses that are processed
+    # by dataclass_array_container, where it's used to recognize attributes
+    # that are container-typed.


+class ArithArrayContainer(ArrayContainer, Protocol):
+    """
+    A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
+    """

+    # This is loose and permissive, assuming that any array can be added
+    # to any container. The alternative would be to plaster type-ignores
+    # on all those uses. Achieving typing precision on what broadcasting is
+    # allowable seems like a huge endeavor and is likely not feasible without
+    # a mypy plugin. Maybe some day? -AK, November 2024

+    def __neg__(self) -> Self: ...
+    def __abs__(self) -> Self: ...
+    def __add__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __radd__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __sub__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __mul__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ...


+ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)


 class NotAnArrayContainerError(TypeError):
     """:class:`TypeError` subclass raised when an array container is expected."""


+SerializationKey: TypeAlias = Hashable
+SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]]
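As a rough downstream-usage sketch (the ``Velocity`` container, its fields, and the ``double`` helper are invented for illustration; ``dataclass_array_container`` and ``rec_map_array_container`` are the package's own utilities):

    from dataclasses import dataclass

    from arraycontext import (
        Array, ArrayContainerT, dataclass_array_container,
        rec_map_array_container)


    @dataclass_array_container
    @dataclass(frozen=True)
    class Velocity:
        # Fields annotated with an array (or container) type are what
        # dataclass_array_container recognizes as container components.
        u: Array
        v: Array


    def double(c: ArrayContainerT) -> ArrayContainerT:
        # ArrayContainerT expresses "returns the same container type it was
        # given"; rec_map_array_container applies the callable to every leaf.
        return rec_map_array_container(lambda ary: 2*ary, c)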
 @singledispatch
-def serialize_container(ary: Any) -> Iterable[Tuple[Any, Any]]:
-    r"""Serialize the array container into an iterable over its components.
+def serialize_container(
+        ary: ArrayContainer) -> SerializedContainer:
+    r"""Serialize the array container into a sequence over its components.

     The order of the components and their identifiers are entirely under
     the control of the container class. However, the order is required to be
     deterministic, i.e. two calls to :func:`serialize_container` on
     array containers of the same types with the same number of
-    sub-arrays must result in an iterable with the keys in the same
+    sub-arrays must result in a sequence with the keys in the same
     order.

     If *ary* is mutable, the serialization function is not required to ensure
     that the serialization result reflects the array state at the time of the
     call to :func:`serialize_container`.

-    :returns: an :class:`Iterable` of 2-tuples where the first
+    :returns: a :class:`Sequence` of 2-tuples where the first
         entry is an identifier for the component and the second entry
         is an array-like component of the :class:`ArrayContainer`.
         Components can themselves be :class:`ArrayContainer`\ s, allowing
@@ -148,13 +209,15 @@ def serialize_container(ary: Any) -> Iterable[Tuple[Any, Any]]:

 @singledispatch
-def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) -> Any:
-    """Deserialize an iterable into an array container.
+def deserialize_container(
+        template: ArrayContainerT,
+        serialized: SerializedContainer) -> ArrayContainerT:
+    """Deserialize a sequence into an array container following a *template*.

     :param template: an instance of an existing object that
         can be used to aid in the deserialization. For a similar choice
         see :attr:`~numpy.class.__array_finalize__`.
-    :param iterable: an iterable that mirrors the output of
+    :param serialized: a sequence that mirrors the output of
         :meth:`serialize_container`.
     """
     raise NotAnArrayContainerError(
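For context, registering these hooks for a custom container type might look roughly as follows; the ``Pair`` class here is hypothetical (most user code would rely on ``dataclass_array_container`` instead):

    from arraycontext import deserialize_container, serialize_container


    class Pair:
        """A hypothetical two-component container."""
        def __init__(self, first, second):
            self.first = first
            self.second = second


    @serialize_container.register(Pair)
    def _serialize_pair(ary: Pair):
        # Deterministic order; the identifiers name the components.
        return [("first", ary.first), ("second", ary.second)]


    @deserialize_container.register(Pair)
    def _deserialize_pair(template: Pair, serialized):
        return Pair(**dict(serialized))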
...@@ -181,7 +244,7 @@ def is_array_container_type(cls: type) -> bool: ...@@ -181,7 +244,7 @@ def is_array_container_type(cls: type) -> bool:
is not serialize_container.__wrapped__)) # type:ignore[attr-defined] is not serialize_container.__wrapped__)) # type:ignore[attr-defined]
def is_array_container(ary: Any) -> bool: def is_array_container(ary: object) -> bool:
""" """
:returns: *True* if the instance *ary* has a registered implementation of :returns: *True* if the instance *ary* has a registered implementation of
:func:`serialize_container`. :func:`serialize_container`.
...@@ -194,11 +257,15 @@ def is_array_container(ary: Any) -> bool: ...@@ -194,11 +257,15 @@ def is_array_container(ary: Any) -> bool:
"cheaper option, see is_array_container_type.", "cheaper option, see is_array_container_type.",
DeprecationWarning, stacklevel=2) DeprecationWarning, stacklevel=2)
return (serialize_container.dispatch(ary.__class__) return (serialize_container.dispatch(ary.__class__)
is not serialize_container.__wrapped__) # type:ignore[attr-defined] is not serialize_container.__wrapped__ # type:ignore[attr-defined]
# numpy values with scalar elements aren't array containers
and not (isinstance(ary, np.ndarray)
and ary.dtype.kind != "O")
)
@singledispatch @singledispatch
def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]: def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None:
"""Retrieves the :class:`ArrayContext` from the container, if any. """Retrieves the :class:`ArrayContext` from the container, if any.
This function is not recursive, so it will only search at the root level This function is not recursive, so it will only search at the root level
...@@ -213,10 +280,11 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]: ...@@ -213,10 +280,11 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]:
# {{{ object arrays as array containers # {{{ object arrays as array containers
@serialize_container.register(np.ndarray) @serialize_container.register(np.ndarray)
def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]: def _serialize_ndarray_container(
ary: numpy.ndarray) -> SerializedContainer:
if ary.dtype.char != "O": if ary.dtype.char != "O":
raise NotAnArrayContainerError( raise NotAnArrayContainerError(
f"cannot seriealize '{type(ary).__name__}' with dtype '{ary.dtype}'") f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'")
# special-cased for speed # special-cased for speed
if ary.ndim == 1: if ary.ndim == 1:
...@@ -227,20 +295,22 @@ def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]: ...@@ -227,20 +295,22 @@ def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]:
for j in range(ary.shape[1]) for j in range(ary.shape[1])
] ]
else: else:
return np.ndenumerate(ary) return list(np.ndenumerate(ary))
@deserialize_container.register(np.ndarray) @deserialize_container.register(np.ndarray)
def _deserialize_ndarray_container( # https://github.com/python/mypy/issues/13040
template: np.ndarray, def _deserialize_ndarray_container( # type: ignore[misc]
iterable: Iterable[Tuple[Any, Any]]) -> np.ndarray: template: numpy.ndarray,
serialized: SerializedContainer) -> numpy.ndarray:
# disallow subclasses # disallow subclasses
assert type(template) is np.ndarray assert type(template) is np.ndarray
assert template.dtype.char == "O" assert template.dtype.char == "O"
result = type(template)(template.shape, dtype=object) result = type(template)(template.shape, dtype=object)
for i, subary in iterable: for i, subary in serialized:
result[i] = subary # FIXME: numpy annotations don't seem to handle object arrays very well
result[i] = subary # type: ignore[call-overload]
return result return result
...@@ -249,15 +319,18 @@ def _deserialize_ndarray_container( ...@@ -249,15 +319,18 @@ def _deserialize_ndarray_container(
# {{{ get_container_context_recursively # {{{ get_container_context_recursively
-def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
+def get_container_context_recursively_opt(
+        ary: ArrayContainer) -> ArrayContext | None:
     """Walks the :class:`ArrayContainer` hierarchy to find an
     :class:`ArrayContext` associated with it.

     If different components that have different array contexts are found at
     any level, an assertion error is raised.
+
+    Returns *None* if no array context was found.
     """
     # try getting the array context directly
-    actx = get_container_context(ary)
+    actx = get_container_context_opt(ary)
     if actx is not None:
         return actx
@@ -267,7 +340,7 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
         return actx
     else:
         for _, subary in iterable:
-            context = get_container_context_recursively(subary)
+            context = get_container_context_recursively_opt(subary)
             if context is None:
                 continue
@@ -280,6 +353,28 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
     return actx


+def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None:
+    """Walks the :class:`ArrayContainer` hierarchy to find an
+    :class:`ArrayContext` associated with it.
+
+    If different components that have different array contexts are found at
+    any level, an assertion error is raised.
+
+    Raises an error if no array container is found.
+    """
+    actx = get_container_context_recursively_opt(ary)
+    if actx is None:
+        # raise ValueError("no array context was found")
+        from warnings import warn
+        warn("No array context was found. This will be an error starting in "
+                "July of 2022. If you would like the function to return "
+                "None if no array context was found, use "
+                "get_container_context_recursively_opt.",
+                DeprecationWarning, stacklevel=2)
+
+    return actx


 # }}}
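A small runnable sketch of the case the two spellings treat differently, using an object array of plain numpy arrays (an array container whose leaves carry no array context):

    import numpy as np

    from arraycontext import get_container_context_recursively_opt

    numpy_state = np.empty((2,), dtype=object)
    numpy_state[0] = np.zeros(3)
    numpy_state[1] = np.ones(3)

    # The *_opt variant simply reports the absence of a context; the older
    # get_container_context_recursively spelling warns here instead (and is
    # slated to treat this as an error).
    assert get_container_context_recursively_opt(numpy_state) is None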
...@@ -288,17 +383,19 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]: ...@@ -288,17 +383,19 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic. # FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic.
# (Though clearly there exists a dependency via loopy.) # (Though clearly there exists a dependency via loopy.)
def _serialize_multivec_as_container(mv: "MultiVector") -> Iterable[Tuple[Any, Any]]: def _serialize_multivec_as_container(mv: MultiVector) -> SerializedContainer:
return list(mv.data.items()) return list(mv.data.items())
def _deserialize_multivec_as_container(template: "MultiVector", # FIXME: Ignored due to https://github.com/python/mypy/issues/13040
iterable: Iterable[Tuple[Any, Any]]) -> "MultiVector": def _deserialize_multivec_as_container( # type: ignore[misc]
template: MultiVector,
serialized: SerializedContainer) -> MultiVector:
from pymbolic.geometric_algebra import MultiVector from pymbolic.geometric_algebra import MultiVector
return MultiVector(dict(iterable), space=template.space) return MultiVector(dict(serialized), space=template.space)
def _get_container_context_from_multivec(mv: "MultiVector") -> None: def _get_container_context_opt_from_multivec(mv: MultiVector) -> None:
return None return None
...@@ -312,8 +409,8 @@ def register_multivector_as_array_container() -> None: ...@@ -312,8 +409,8 @@ def register_multivector_as_array_container() -> None:
serialize_container.register(MultiVector)(_serialize_multivec_as_container) serialize_container.register(MultiVector)(_serialize_multivec_as_container)
deserialize_container.register(MultiVector)( deserialize_container.register(MultiVector)(
_deserialize_multivec_as_container) _deserialize_multivec_as_container)
get_container_context.register(MultiVector)( get_container_context_opt.register(MultiVector)(
_get_container_context_from_multivec) _get_container_context_opt_from_multivec)
assert MultiVector in serialize_container.registry assert MultiVector in serialize_container.registry
# }}} # }}}
......
# mypy: disallow-untyped-defs # mypy: disallow-untyped-defs
from __future__ import annotations
"""
__doc__ = """
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
.. autofunction:: with_container_arithmetic .. autofunction:: with_container_arithmetic
""" """
import enum
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
...@@ -31,7 +33,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -31,7 +33,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from typing import Any, Callable, Optional, Tuple, TypeVar, Union import enum
from collections.abc import Callable
from typing import Any, TypeVar
from warnings import warn
import numpy as np import numpy as np
...@@ -86,7 +91,7 @@ _BINARY_OP_AND_DUNDER = [ ...@@ -86,7 +91,7 @@ _BINARY_OP_AND_DUNDER = [
] ]
def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: def _format_unary_op_str(op_str: str, arg1: tuple[str, ...] | str) -> str:
if isinstance(arg1, tuple): if isinstance(arg1, tuple):
arg1_entry, arg1_container = arg1 arg1_entry, arg1_container = arg1
return (f"{op_str.format(arg1_entry)} " return (f"{op_str.format(arg1_entry)} "
...@@ -96,20 +101,14 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: ...@@ -96,20 +101,14 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
def _format_binary_op_str(op_str: str, def _format_binary_op_str(op_str: str,
arg1: Union[Tuple[str, ...], str], arg1: tuple[str, str] | str,
arg2: Union[Tuple[str, ...], str]) -> str: arg2: tuple[str, str] | str) -> str:
if isinstance(arg1, tuple) and isinstance(arg2, tuple): if isinstance(arg1, tuple) and isinstance(arg2, tuple):
import sys
if sys.version_info >= (3, 10):
strict_arg = ", strict=__debug__"
else:
strict_arg = ""
arg1_entry, arg1_container = arg1 arg1_entry, arg1_container = arg1
arg2_entry, arg2_container = arg2 arg2_entry, arg2_container = arg2
return (f"{op_str.format(arg1_entry, arg2_entry)} " return (f"{op_str.format(arg1_entry, arg2_entry)} "
f"for {arg1_entry}, {arg2_entry} " f"for {arg1_entry}, {arg2_entry} "
f"in zip({arg1_container}, {arg2_container}{strict_arg})") f"in zip({arg1_container}, {arg2_container}, strict=__debug__)")
elif isinstance(arg1, tuple): elif isinstance(arg1, tuple):
arg1_entry, arg1_container = arg1 arg1_entry, arg1_container = arg1
...@@ -124,41 +123,71 @@ def _format_binary_op_str(op_str: str, ...@@ -124,41 +123,71 @@ def _format_binary_op_str(op_str: str,
return op_str.format(arg1, arg2) return op_str.format(arg1, arg2)
class NumpyObjectArrayMetaclass(type):
def __instancecheck__(cls, instance: Any) -> bool:
return isinstance(instance, np.ndarray) and instance.dtype == object
class NumpyObjectArray(metaclass=NumpyObjectArrayMetaclass):
pass
class ComplainingNumpyNonObjectArrayMetaclass(type):
def __instancecheck__(cls, instance: Any) -> bool:
if isinstance(instance, np.ndarray) and instance.dtype != object:
# Example usage site:
# https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149
# where normal is passed in by test_lfr_flux as a 'custom-made'
# numpy array of dtype float64.
warn(
"Broadcasting container against non-object numpy array. "
"This was never documented to work and will now stop working in "
"2025. Convert the array to an object array to preserve the "
"current semantics.", DeprecationWarning, stacklevel=3)
return True
else:
return False
class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMetaclass):
pass
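The two helper classes above rely on the fact that ``isinstance`` defers to the metaclass's ``__instancecheck__``. A standalone sketch of that mechanism, with names invented here:

    import numpy as np


    class ObjectArrayMeta(type):
        def __instancecheck__(cls, instance) -> bool:
            # isinstance(x, ObjectArrayOnly) consults this hook instead of
            # walking the real inheritance hierarchy.
            return isinstance(instance, np.ndarray) and instance.dtype == object


    class ObjectArrayOnly(metaclass=ObjectArrayMeta):
        pass


    assert isinstance(np.array([None, None], dtype=object), ObjectArrayOnly)
    assert not isinstance(np.zeros(3), ObjectArrayOnly)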
 def with_container_arithmetic(
         *,
-        bcast_number: bool = True,
-        _bcast_actx_array_type: Optional[bool] = None,
-        bcast_obj_array: Optional[bool] = None,
-        bcast_numpy_array: bool = False,
-        bcast_container_types: Optional[Tuple[type, ...]] = None,
-        arithmetic: bool = True,
-        matmul: bool = False,
-        bitwise: bool = False,
-        shift: bool = False,
-        _cls_has_array_context_attr: bool = False,
-        eq_comparison: Optional[bool] = None,
-        rel_comparison: Optional[bool] = None) -> Callable[[type], type]:
+        number_bcasts_across: bool | None = None,
+        bcasts_across_obj_array: bool | None = None,
+        container_types_bcast_across: tuple[type, ...] | None = None,
+        arithmetic: bool = True,
+        matmul: bool = False,
+        bitwise: bool = False,
+        shift: bool = False,
+        _cls_has_array_context_attr: bool | None = None,
+        eq_comparison: bool | None = None,
+        rel_comparison: bool | None = None,
+
+        # deprecated:
+        bcast_number: bool | None = None,
+        bcast_obj_array: bool | None = None,
+        bcast_numpy_array: bool = False,
+        _bcast_actx_array_type: bool | None = None,
+        bcast_container_types: tuple[type, ...] | None = None,
+        ) -> Callable[[type], type]:
     """A class decorator that implements built-in operators for array containers
     by propagating the operations to the elements of the container.

-    :arg bcast_number: If *True*, numbers broadcast over the container
+    :arg number_bcasts_across: If *True*, numbers broadcast over the container
         (with the container as the 'outer' structure).
-    :arg _bcast_actx_array_type: If *True*, instances of base array types of the
-        container's array context are broadcasted over the container. Can be
-        *True* only if the container has *_cls_has_array_context_attr* set.
-        Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set,
-        else *False*.
-    :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over
-        the container. (with the container as the 'inner' structure)
-    :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast
-        over the container. (with the container as the 'inner' structure)
-        If this is set to *True*, *bcast_obj_array* must also be *True*.
-    :arg bcast_container_types: A sequence of container types that will broadcast
-        over this container (with this container as the 'outer' structure).
+    :arg bcasts_across_obj_array: If *True*, this container will be broadcast
+        across :mod:`numpy` object arrays
+        (with the object array as the 'outer' structure).
+        Add :class:`numpy.ndarray` to *container_types_bcast_across* to achieve
+        the 'reverse' broadcasting.
+    :arg container_types_bcast_across: A sequence of container types that will broadcast
+        across this container, with this container as the 'outer' structure.
         :class:`numpy.ndarray` is permitted to be part of this sequence to
-        indicate that, in such broadcasting situations, this container should
-        be the 'outer' structure. In this case, *bcast_obj_array*
-        (and consequently *bcast_numpy_array*) must be *False*.
+        indicate that object arrays (and *only* object arrays) will be broadcast.
+        In this case, *bcasts_across_obj_array* must be *False*.
     :arg arithmetic: Implement the conventional arithmetic operators, including
         ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as
         :func:`abs`.
@@ -172,6 +201,8 @@ def with_container_arithmetic(
         class has an ``array_context`` attribute. If so, and if :data:`__debug__`
         is *True*, an additional check is performed in binary operators
         to ensure that both containers use the same array context.
+        If *None* (the default), this value is set based on whether the class
+        has an ``array_context`` attribute.
         Consider this argument an unstable interface. It may disappear at any moment.

     Each operator class also includes the "reverse" operators if applicable.
@@ -198,13 +229,18 @@ def with_container_arithmetic(
         should nest "outside" :func:dataclass_array_container`.
     """
# {{{ handle inputs # Hard-won design lessons:
#
if bcast_obj_array is None: # - Anything that special-cases np.ndarray by type is broken by design because:
raise TypeError("bcast_obj_array must be specified") # - np.ndarray is an array context array.
# - numpy object arrays can be array containers.
# Using NumpyObjectArray and NumpyNonObjectArray *may* be better?
# They're new, so there is no operational experience with them.
#
# - Broadcast rules are hard to change once established, particularly
# because one cannot grep for their use.
if rel_comparison is None: # {{{ handle inputs
raise TypeError("rel_comparison must be specified")
if rel_comparison and eq_comparison is None: if rel_comparison and eq_comparison is None:
eq_comparison = True eq_comparison = True
...@@ -212,37 +248,104 @@ def with_container_arithmetic( ...@@ -212,37 +248,104 @@ def with_container_arithmetic(
if eq_comparison is None: if eq_comparison is None:
raise TypeError("eq_comparison must be specified") raise TypeError("eq_comparison must be specified")
if not bcast_obj_array and bcast_numpy_array: # {{{ handle bcast_number
raise TypeError("bcast_obj_array must be set if bcast_numpy_array is")
if _bcast_actx_array_type is None: if bcast_number is not None:
if _cls_has_array_context_attr: if number_bcasts_across is not None:
_bcast_actx_array_type = bcast_number raise TypeError(
else: "may specify at most one of 'bcast_number' and "
_bcast_actx_array_type = False "'number_bcasts_across'")
warn("'bcast_number' is deprecated and will be unsupported from 2025. "
"Use 'number_bcasts_across', with equivalent meaning.",
DeprecationWarning, stacklevel=2)
number_bcasts_across = bcast_number
else:
if number_bcasts_across is None:
number_bcasts_across = True
del bcast_number
# }}}
# {{{ handle bcast_obj_array
if bcast_obj_array is not None:
if bcasts_across_obj_array is not None:
raise TypeError(
"may specify at most one of 'bcast_obj_array' and "
"'bcasts_across_obj_array'")
warn("'bcast_obj_array' is deprecated and will be unsupported from 2025. "
"Use 'bcasts_across_obj_array', with equivalent meaning.",
DeprecationWarning, stacklevel=2)
bcasts_across_obj_array = bcast_obj_array
else:
if bcasts_across_obj_array is None:
raise TypeError("bcasts_across_obj_array must be specified")
del bcast_obj_array
# }}}
# {{{ handle bcast_container_types
if bcast_container_types is not None:
if container_types_bcast_across is not None:
raise TypeError(
"may specify at most one of 'bcast_container_types' and "
"'container_types_bcast_across'")
warn("'bcast_container_types' is deprecated and will be unsupported from 2025. "
"Use 'container_types_bcast_across', with equivalent meaning.",
DeprecationWarning, stacklevel=2)
container_types_bcast_across = bcast_container_types
else: else:
if _bcast_actx_array_type and not _cls_has_array_context_attr: if container_types_bcast_across is None:
raise TypeError("_bcast_actx_array_type can be True only if " container_types_bcast_across = ()
"_cls_has_array_context_attr is set.")
del bcast_container_types
# }}}
if rel_comparison is None:
raise TypeError("rel_comparison must be specified")
if bcast_numpy_array:
warn("'bcast_numpy_array=True' is deprecated and will be unsupported"
" from 2025.", DeprecationWarning, stacklevel=2)
if _bcast_actx_array_type:
raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'"
" cannot be both set.")
if not bcasts_across_obj_array and bcast_numpy_array:
raise TypeError("bcast_obj_array must be set if bcast_numpy_array is")
if bcast_numpy_array: if bcast_numpy_array:
def numpy_pred(name: str) -> str: def numpy_pred(name: str) -> str:
return f"isinstance({name}, np.ndarray)" return f"is_numpy_array({name})"
elif bcast_obj_array: elif bcasts_across_obj_array:
def numpy_pred(name: str) -> str: def numpy_pred(name: str) -> str:
return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'" return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'"
else: else:
def numpy_pred(name: str) -> str: def numpy_pred(name: str) -> str:
return "False" # optimized away return "False" # optimized away
if bcast_container_types is None: if np.ndarray in container_types_bcast_across and bcasts_across_obj_array:
bcast_container_types = ()
bcast_container_types_count = len(bcast_container_types)
if np.ndarray in bcast_container_types and bcast_obj_array:
raise ValueError("If numpy.ndarray is part of bcast_container_types, " raise ValueError("If numpy.ndarray is part of bcast_container_types, "
"bcast_obj_array must be False.") "bcast_obj_array must be False.")
numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray]
container_types_bcast_across = tuple(
new_ct
for old_ct in container_types_bcast_across
for new_ct in
(numpy_check_types
if old_ct is np.ndarray
else [old_ct])
)
desired_op_classes = set() desired_op_classes = set()
if arithmetic: if arithmetic:
desired_op_classes.add(_OpClass.ARITHMETIC) desired_op_classes.add(_OpClass.ARITHMETIC)
...@@ -260,6 +363,64 @@ def with_container_arithmetic( ...@@ -260,6 +363,64 @@ def with_container_arithmetic(
# }}} # }}}
def wrap(cls: Any) -> Any: def wrap(cls: Any) -> Any:
if not hasattr(cls, "__array_ufunc__"):
warn(f"{cls} does not have __array_ufunc__ set. "
"This will cause numpy to attempt broadcasting, in a way that "
"is likely undesired. "
f"To avoid this, set __array_ufunc__ = None in {cls}.",
stacklevel=2)
cls_has_array_context_attr: bool | None = _cls_has_array_context_attr
bcast_actx_array_type: bool | None = _bcast_actx_array_type
if cls_has_array_context_attr is None and hasattr(cls, "array_context"):
raise TypeError(
f"{cls} has an 'array_context' attribute, but it does not "
"set '_cls_has_array_context_attr' to *True* when calling "
"with_container_arithmetic. This is being interpreted "
"as '.array_context' being permitted to fail "
"with an exception, which is no longer allowed. "
f"If {cls.__name__}.array_context will not fail, pass "
"'_cls_has_array_context_attr=True'. "
"If you do not want container arithmetic to make "
"use of the array context, set "
"'_cls_has_array_context_attr=False'.")
if bcast_actx_array_type is None:
if cls_has_array_context_attr:
if number_bcasts_across:
bcast_actx_array_type = cls_has_array_context_attr
else:
bcast_actx_array_type = False
else:
if bcast_actx_array_type and not cls_has_array_context_attr:
raise TypeError("_bcast_actx_array_type can be True only if "
"_cls_has_array_context_attr is set.")
if bcast_actx_array_type:
if _bcast_actx_array_type:
warn(
f"Broadcasting array context array types across {cls} "
"has been explicitly "
"enabled. As of 2025, this will stop working. "
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/pull/190. "
"To opt out now (and avoid this warning), "
"pass _bcast_actx_array_type=False. ",
DeprecationWarning, stacklevel=2)
else:
warn(
f"Broadcasting array context array types across {cls} "
"has been implicitly "
"enabled. As of 2025, this will no longer work. "
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/pull/190. "
"To opt out now (and avoid this warning), "
"pass _bcast_actx_array_type=False.",
DeprecationWarning, stacklevel=2)
if (not hasattr(cls, "_serialize_init_arrays_code") if (not hasattr(cls, "_serialize_init_arrays_code")
or not hasattr(cls, "_deserialize_init_arrays_code")): or not hasattr(cls, "_deserialize_init_arrays_code")):
raise TypeError(f"class '{cls.__name__}' must provide serialization " raise TypeError(f"class '{cls.__name__}' must provide serialization "
...@@ -270,42 +431,65 @@ def with_container_arithmetic( ...@@ -270,42 +431,65 @@ def with_container_arithmetic(
from pytools.codegen import CodeGenerator, Indentation from pytools.codegen import CodeGenerator, Indentation
gen = CodeGenerator() gen = CodeGenerator()
gen(""" gen(f"""
from numbers import Number from numbers import Number
import numpy as np import numpy as np
from arraycontext import ArrayContainer from arraycontext import ArrayContainer
from warnings import warn
def _raise_if_actx_none(actx): def _raise_if_actx_none(actx):
if actx is None: if actx is None:
raise ValueError("array containers with frozen arrays " raise ValueError("array containers with frozen arrays "
"cannot be operated upon") "cannot be operated upon")
return actx return actx
def is_numpy_array(arg):
if isinstance(arg, np.ndarray):
if arg.dtype != "O":
warn("Operand is a non-object numpy array, "
"and the broadcasting behavior of this array container "
"({cls}) "
"is influenced by this because of its use of "
"the deprecated bcast_numpy_array. This broadcasting "
"behavior will change in 2025. If you would like the "
"broadcasting behavior to stay the same, make sure "
"to convert the passed numpy array to an "
"object array.",
DeprecationWarning, stacklevel=3)
return True
else:
return False
""") """)
gen("") gen("")
if bcast_container_types: if container_types_bcast_across:
for i, bct in enumerate(bcast_container_types): for i, bct in enumerate(container_types_bcast_across):
gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}") gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}")
gen("") gen("")
outer_bcast_type_names = tuple([ container_type_names_bcast_across = tuple(
f"_bctype{i}" for i in range(bcast_container_types_count) f"_bctype{i}" for i in range(len(container_types_bcast_across)))
]) if number_bcasts_across:
if bcast_number: container_type_names_bcast_across += ("Number",)
outer_bcast_type_names += ("Number",)
def same_key(k1: T, k2: T) -> T: def same_key(k1: T, k2: T) -> T:
assert k1 == k2 assert k1 == k2
return k1 return k1
def tup_str(t: Tuple[str, ...]) -> str: def tup_str(t: tuple[str, ...]) -> str:
if not t: if not t:
return "()" return "()"
else: else:
return "(%s,)" % ", ".join(t) return "({},)".format(", ".join(t))
gen(f"cls._outer_bcast_types = {tup_str(container_type_names_bcast_across)}")
gen("cls._container_types_bcast_across = "
f"{tup_str(container_type_names_bcast_across)}")
gen(f"cls._outer_bcast_types = {tup_str(outer_bcast_type_names)}")
gen(f"cls._bcast_numpy_array = {bcast_numpy_array}") gen(f"cls._bcast_numpy_array = {bcast_numpy_array}")
gen(f"cls._bcast_obj_array = {bcast_obj_array}")
gen(f"cls._bcast_obj_array = {bcasts_across_obj_array}")
gen(f"cls._bcasts_across_obj_array = {bcasts_across_obj_array}")
gen("") gen("")
# {{{ unary operators # {{{ unary operators
...@@ -349,35 +533,43 @@ def with_container_arithmetic( ...@@ -349,35 +533,43 @@ def with_container_arithmetic(
continue continue
# {{{ "forward" binary operators
zip_init_args = cls._deserialize_init_arrays_code("arg1", { zip_init_args = cls._deserialize_init_arrays_code("arg1", {
same_key(key_arg1, key_arg2): same_key(key_arg1, key_arg2):
_format_binary_op_str(op_str, expr_arg1, expr_arg2) _format_binary_op_str(op_str, expr_arg1, expr_arg2)
for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip( for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip(
cls._serialize_init_arrays_code("arg1").items(), cls._serialize_init_arrays_code("arg1").items(),
cls._serialize_init_arrays_code("arg2").items()) cls._serialize_init_arrays_code("arg2").items(),
strict=True)
}) })
bcast_same_cls_init_args = cls._deserialize_init_arrays_code("arg1", { bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", {
key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
for key_arg1, expr_arg1 in for key_arg1, expr_arg1 in
cls._serialize_init_arrays_code("arg1").items() cls._serialize_init_arrays_code("arg1").items()
}) })
bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", {
key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2)
for key_arg2, expr_arg2 in
cls._serialize_init_arrays_code("arg2").items()
})
# {{{ "forward" binary operators
gen(f"def {fname}(arg1, arg2):") gen(f"def {fname}(arg1, arg2):")
with Indentation(gen): with Indentation(gen):
gen("if arg2.__class__ is cls:") gen("if arg2.__class__ is cls:")
with Indentation(gen): with Indentation(gen):
if __debug__ and _cls_has_array_context_attr: if __debug__ and cls_has_array_context_attr:
gen(""" gen("""
if arg1.array_context is not arg2.array_context: arg1_actx = arg1.array_context
arg2_actx = arg2.array_context
if arg1_actx is not arg2_actx:
msg = ("array contexts of both arguments " msg = ("array contexts of both arguments "
"must match") "must match")
if arg1.array_context is None: if arg1_actx is None:
raise ValueError(msg raise ValueError(msg
+ ": left operand is frozen " + ": left operand is frozen "
"(i.e. has no array context)") "(i.e. has no array context)")
elif arg2.array_context is None: elif arg2_actx is None:
raise ValueError(msg raise ValueError(msg
+ ": right operand is frozen " + ": right operand is frozen "
"(i.e. has no array context)") "(i.e. has no array context)")
...@@ -385,26 +577,41 @@ def with_container_arithmetic( ...@@ -385,26 +577,41 @@ def with_container_arithmetic(
raise ValueError(msg)""") raise ValueError(msg)""")
gen(f"return cls({zip_init_args})") gen(f"return cls({zip_init_args})")
if _bcast_actx_array_type: if bcast_actx_array_type:
if __debug__: if __debug__:
bcast_actx_ary_types: Tuple[str, ...] = ( bcast_actx_ary_types: tuple[str, ...] = (
"*_raise_if_actx_none(arg1.array_context).array_types",) "*_raise_if_actx_none("
"arg1.array_context).array_types",)
else: else:
bcast_actx_ary_types = ("*arg1.array_context.array_types",) bcast_actx_ary_types = (
"*arg1.array_context.array_types",)
else: else:
bcast_actx_ary_types = () bcast_actx_ary_types = ()
gen(f""" gen(f"""
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg2,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_same_cls_init_args})
if {numpy_pred("arg2")}: if {numpy_pred("arg2")}:
result = np.empty_like(arg2, dtype=object) result = np.empty_like(arg2, dtype=object)
for i in np.ndindex(arg2.shape): for i in np.ndindex(arg2.shape):
result[i] = {op_str.format("arg1", "arg2[i]")} result[i] = {op_str.format("arg1", "arg2[i]")}
return result return result
if {bool(container_type_names_bcast_across)}: # optimized away
if isinstance(arg2,
{tup_str(container_type_names_bcast_across
+ bcast_actx_ary_types)}):
if __debug__:
if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
warn("Broadcasting {cls} over array "
f"context array type {{type(arg2)}} is deprecated "
"and will no longer work in 2025. "
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/"
"pull/190. ",
DeprecationWarning, stacklevel=2)
return cls({bcast_init_args_arg1_is_outer})
return NotImplemented return NotImplemented
""") """)
gen(f"cls.__{dunder_name}__ = {fname}") gen(f"cls.__{dunder_name}__ = {fname}")
...@@ -416,19 +623,15 @@ def with_container_arithmetic( ...@@ -416,19 +623,15 @@ def with_container_arithmetic(
if reversible: if reversible:
fname = f"_{cls.__name__.lower()}_r{dunder_name}" fname = f"_{cls.__name__.lower()}_r{dunder_name}"
bcast_init_args = cls._deserialize_init_arrays_code("arg2", {
key_arg2: _format_binary_op_str( if bcast_actx_array_type:
op_str, "arg1", expr_arg2)
for key_arg2, expr_arg2 in
cls._serialize_init_arrays_code("arg2").items()
})
if _bcast_actx_array_type:
if __debug__: if __debug__:
bcast_actx_ary_types = ( bcast_actx_ary_types = (
"*_raise_if_actx_none(arg2.array_context).array_types",) "*_raise_if_actx_none("
"arg2.array_context).array_types",)
else: else:
bcast_actx_ary_types = ("*arg2.array_context.array_types",) bcast_actx_ary_types = (
"*arg2.array_context.array_types",)
else: else:
bcast_actx_ary_types = () bcast_actx_ary_types = ()
...@@ -436,16 +639,30 @@ def with_container_arithmetic( ...@@ -436,16 +639,30 @@ def with_container_arithmetic(
def {fname}(arg2, arg1): def {fname}(arg2, arg1):
# assert other.__cls__ is not cls # assert other.__cls__ is not cls
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg1,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_init_args})
if {numpy_pred("arg1")}: if {numpy_pred("arg1")}:
result = np.empty_like(arg1, dtype=object) result = np.empty_like(arg1, dtype=object)
for i in np.ndindex(arg1.shape): for i in np.ndindex(arg1.shape):
result[i] = {op_str.format("arg1[i]", "arg2")} result[i] = {op_str.format("arg1[i]", "arg2")}
return result return result
if {bool(container_type_names_bcast_across)}: # optimized away
if isinstance(arg1,
{tup_str(container_type_names_bcast_across
+ bcast_actx_ary_types)}):
if __debug__:
if isinstance(arg1,
{tup_str(bcast_actx_ary_types)}):
warn("Broadcasting {cls} over array "
f"context array type {{type(arg1)}} "
"is deprecated "
"and will no longer work in 2025."
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/"
"pull/190. ",
DeprecationWarning, stacklevel=2)
return cls({bcast_init_args_arg2_is_outer})
return NotImplemented return NotImplemented
cls.__r{dunder_name}__ = {fname}""") cls.__r{dunder_name}__ = {fname}""")
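
# Illustrative usage sketch (hypothetical `State` class): the operators
# generated above are typically attached to a dataclass container by stacking
# the decorators as below. Keyword spellings are in flux in this change
# (e.g. bcast_obj_array vs. bcasts_across_obj_array), so treat the arguments
# as indicative rather than authoritative.
from dataclasses import dataclass

from arraycontext import (
    Array, dataclass_array_container, with_container_arithmetic)


@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True,
                           _cls_has_array_context_attr=False)
@dataclass_array_container
@dataclass(frozen=True)
class State:
    mass: Array
    momentum: Array

# The generated methods (__add__, __radd__, __mul__, ...) make expressions
# like state_a + state_b or 2*state_a act componentwise on the fields.
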
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
.. autofunction:: dataclass_array_container .. autofunction:: dataclass_array_container
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
...@@ -30,12 +31,28 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -30,12 +31,28 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from dataclasses import fields from collections.abc import Mapping, Sequence
from dataclasses import fields, is_dataclass
from typing import NamedTuple, Union, get_args, get_origin
from arraycontext.container import is_array_container_type from arraycontext.container import is_array_container_type
# {{{ dataclass containers # {{{ dataclass containers
class _Field(NamedTuple):
"""Small lookalike for :class:`dataclasses.Field`."""
init: bool
name: str
type: type
def is_array_type(tp: type) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)
def dataclass_array_container(cls: type) -> type: def dataclass_array_container(cls: type) -> type:
"""A class decorator that makes the class to which it is applied an """A class decorator that makes the class to which it is applied an
:class:`ArrayContainer` by registering appropriate implementations of :class:`ArrayContainer` by registering appropriate implementations of
...@@ -44,41 +61,140 @@ def dataclass_array_container(cls: type) -> type: ...@@ -44,41 +61,140 @@ def dataclass_array_container(cls: type) -> type:
Attributes that are not array containers are allowed. In order to decide Attributes that are not array containers are allowed. In order to decide
whether an attribute is an array container, the declared attribute type whether an attribute is an array container, the declared attribute type
is checked by the criteria from :func:`is_array_container_type`. is checked by the criteria from :func:`is_array_container_type`. This
includes some support for type annotations:
* a :class:`typing.Union` of array containers is considered an array container.
* other type annotations, e.g. :class:`typing.Optional`, are not considered
array containers, even if they wrap one.
.. note::
When type annotations are strings (e.g. because of
``from __future__ import annotations``),
this function relies on :func:`inspect.get_annotations`
(with ``eval_str=True``) to obtain type annotations. This
means that *cls* must live in a module that is importable.
""" """
from dataclasses import is_dataclass, Field
from types import GenericAlias, UnionType
assert is_dataclass(cls) assert is_dataclass(cls)
def is_array_field(f: Field) -> bool: def is_array_field(f: _Field) -> bool:
field_type = f.type
# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
#
# `Optional[ArrayContainer]` is not allowed, since `None` is not
# handled by `with_container_arithmetic`, which is the common case
# for current container usage. Other type annotations, e.g.
# `Tuple[Container, Container]`, are also not allowed, as they do not
# work with `with_container_arithmetic`.
#
# This is not set in stone, but mostly driven by current usage!
origin = get_origin(field_type)
# NOTE: `UnionType` is returned when using `Type1 | Type2`
if origin in (Union, UnionType):
if all(is_array_type(arg) for arg in get_args(field_type)):
return True
else:
raise TypeError(
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")
# NOTE: this should never happen due to using `inspect.get_annotations`
assert not isinstance(field_type, str)
if __debug__: if __debug__:
if not f.init: if not f.init:
raise ValueError( raise ValueError(
f"'init=False' field not allowed: '{f.name}'") f"Field with 'init=False' not allowed: '{f.name}'")
if isinstance(f.type, str): # NOTE:
raise TypeError( # * `GenericAlias` catches typed `list`, `tuple`, etc.
f"string annotation on field '{f.name}' not supported") # * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
# * `_SpecialForm` catches `Any`, `Literal`, etc.
from typing import _SpecialForm from typing import ( # type: ignore[attr-defined]
if isinstance(f.type, _SpecialForm): _BaseGenericAlias,
_SpecialForm,
)
if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm):
# NOTE: anything except a Union is not allowed
raise TypeError( raise TypeError(
f"typing annotation not supported on field '{f.name}': " f"Typing annotation not supported on field '{f.name}': "
f"'{f.type!r}'") f"'{field_type!r}'")
if not isinstance(f.type, type): if not isinstance(field_type, type):
raise TypeError( raise TypeError(
f"field '{f.name}' not an instance of 'type': " f"Field '{f.name}' not an instance of 'type': "
f"'{f.type!r}'") f"'{field_type!r}'")
return is_array_container_type(f.type) return is_array_type(field_type)
from pytools import partition from pytools import partition
array_fields, non_array_fields = partition(is_array_field, fields(cls))
array_fields = _get_annotated_fields(cls)
array_fields, non_array_fields = partition(is_array_field, array_fields)
if not array_fields: if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type " raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator") "in order to use the 'dataclass_array_container' decorator")
return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
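
# Illustrative sketch (hypothetical `Conserved` class) of what the field-type
# rules above accept: plain non-array fields are carried along, unions of
# array types count as array fields, while e.g. Optional[...] fields would be
# rejected.
from dataclasses import dataclass

import numpy as np

from arraycontext import Array, dataclass_array_container


@dataclass_array_container
@dataclass(frozen=True)
class Conserved:
    dim: int                       # non-array field
    mass: Array                    # leaf array
    momentum: Array | np.ndarray   # union of array/container types
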
def _get_annotated_fields(cls: type) -> Sequence[_Field]:
"""Get a list of fields in the class *cls* with evaluated types.
If any of the fields in *cls* have type annotations that are strings, e.g.
from using ``from __future__ import annotations``, this function evaluates
them using :func:`inspect.get_annotations`. Note that this requires the class
to live in a module that is importable.
:return: a list of fields.
"""
from inspect import get_annotations
result = []
cls_ann: Mapping[str, type] | None = None
for field in fields(cls):
field_type_or_str = field.type
if isinstance(field_type_or_str, str):
if cls_ann is None:
cls_ann = get_annotations(cls, eval_str=True)
field_type = cls_ann[field.name]
else:
field_type = field_type_or_str
result.append(_Field(init=field.init, name=field.name, type=field_type))
return result
def _inject_dataclass_serialization(
cls: type,
array_fields: Sequence[_Field],
non_array_fields: Sequence[_Field]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
This function modifies *cls* in place, so the returned value is the same
object with additional functionality.
:arg array_fields: fields of the given dataclass *cls* which are considered
array containers and should be serialized.
:arg non_array_fields: remaining fields of the dataclass *cls* which are
copied over from the template array in deserialization.
"""
assert is_dataclass(cls)
serialize_expr = ", ".join( serialize_expr = ", ".join(
f"({f.name!r}, ary.{f.name})" for f in array_fields) f"({f.name!r}, ary.{f.name})" for f in array_fields)
template_kwargs = ", ".join( template_kwargs = ", ".join(
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
.. autofunction:: rec_map_reduce_array_container .. autofunction:: rec_map_reduce_array_container
.. autofunction:: rec_multimap_reduce_array_container .. autofunction:: rec_multimap_reduce_array_container
.. autofunction:: stringify_array_container_tree
Traversing decorators Traversing decorators
~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: mapped_over_array_containers .. autofunction:: mapped_over_array_containers
...@@ -27,6 +29,7 @@ Flattening and unflattening ...@@ -27,6 +29,7 @@ Flattening and unflattening
~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: flatten .. autofunction:: flatten
.. autofunction:: unflatten .. autofunction:: unflatten
.. autofunction:: flat_size_and_dtype
Numpy conversion Numpy conversion
~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~
...@@ -38,6 +41,11 @@ Algebraic operations ...@@ -38,6 +41,11 @@ Algebraic operations
.. autofunction:: outer .. autofunction:: outer
""" """
from __future__ import annotations
from arraycontext.container.arithmetic import NumpyObjectArray
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
""" """
...@@ -62,24 +70,39 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -62,24 +70,39 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from typing import Any, Callable, Iterable, List, Optional, Union, Tuple from collections.abc import Callable, Iterable
from functools import update_wrapper, partial, singledispatch from functools import partial, singledispatch, update_wrapper
from typing import Any, cast
from warnings import warn
import numpy as np import numpy as np
from arraycontext.context import ArrayContext, DeviceArray, _ScalarLike
from arraycontext.container import ( from arraycontext.container import (
ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError, ArrayContainer,
serialize_container, deserialize_container) NotAnArrayContainerError,
SerializationKey,
deserialize_container,
get_container_context_recursively_opt,
serialize_container,
)
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerT,
ArrayT,
ScalarLike,
)
# {{{ array container traversal helpers # {{{ array container traversal helpers
def _map_array_container_impl( def _map_array_container_impl(
f: Callable[[Any], Any], f: Callable[[ArrayOrContainer], ArrayOrContainer],
ary: ArrayOrContainerT, *, ary: ArrayOrContainer, *,
leaf_cls: Optional[type] = None, leaf_cls: type | None = None,
recursive: bool = False) -> ArrayOrContainerT: recursive: bool = False) -> ArrayOrContainer:
"""Helper for :func:`rec_map_array_container`. """Helper for :func:`rec_map_array_container`.
:param leaf_cls: class on which we call *f* directly. This is mostly :param leaf_cls: class on which we call *f* directly. This is mostly
...@@ -87,16 +110,16 @@ def _map_array_container_impl( ...@@ -87,16 +110,16 @@ def _map_array_container_impl(
specific container classes. By default, the recursion is stopped when specific container classes. By default, the recursion is stopped when
a non-:class:`ArrayContainer` class is encountered. a non-:class:`ArrayContainer` class is encountered.
""" """
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: def rec(ary_: ArrayOrContainer) -> ArrayOrContainer:
if type(_ary) is leaf_cls: # type(ary) is never None if type(ary_) is leaf_cls: # type(ary) is never None
return f(_ary) return f(ary_)
try: try:
iterable = serialize_container(_ary) iterable = serialize_container(ary_)
except NotAnArrayContainerError: except NotAnArrayContainerError:
return f(_ary) return f(ary_)
else: else:
return deserialize_container(_ary, [ return deserialize_container(ary_, [
(key, frec(subary)) for key, subary in iterable (key, frec(subary)) for key, subary in iterable
]) ])
...@@ -107,9 +130,10 @@ def _map_array_container_impl( ...@@ -107,9 +130,10 @@ def _map_array_container_impl(
def _multimap_array_container_impl( def _multimap_array_container_impl(
f: Callable[..., Any], f: Callable[..., Any],
*args: Any, *args: Any,
reduce_func: Callable[[ContainerT, Iterable[Tuple[Any, Any]]], Any] = None, reduce_func: (
leaf_cls: Optional[type] = None, Callable[[ArrayContainer, Iterable[tuple[Any, Any]]], Any] | None) = None,
recursive: bool = False) -> ArrayOrContainerT: leaf_cls: type | None = None,
recursive: bool = False) -> ArrayOrContainer:
"""Helper for :func:`rec_multimap_array_container`. """Helper for :func:`rec_multimap_array_container`.
:param leaf_cls: class on which we call *f* directly. This is mostly :param leaf_cls: class on which we call *f* directly. This is mostly
...@@ -120,31 +144,30 @@ def _multimap_array_container_impl( ...@@ -120,31 +144,30 @@ def _multimap_array_container_impl(
# {{{ recursive traversal # {{{ recursive traversal
def rec(*_args: Any) -> Any: def rec(*args_: Any) -> Any:
template_ary = _args[container_indices[0]] template_ary = args_[container_indices[0]]
if type(template_ary) is leaf_cls: if type(template_ary) is leaf_cls:
return f(*_args) return f(*args_)
try: try:
iterable_template = serialize_container(template_ary) iterable_template = serialize_container(template_ary)
except NotAnArrayContainerError: except NotAnArrayContainerError:
return f(*_args) return f(*args_)
else:
pass
assert all( assert all(
type(_args[i]) is type(template_ary) for i in container_indices[1:] type(args_[i]) is type(template_ary) for i in container_indices[1:]
), f"expected type '{type(template_ary).__name__}'" ), f"expected type '{type(template_ary).__name__}'"
result = [] result = []
new_args = list(_args) new_args = list(args_)
for subarys in zip( for subarys in zip(
iterable_template, iterable_template,
*[serialize_container(_args[i]) for i in container_indices[1:]] *[serialize_container(args_[i]) for i in container_indices[1:]],
strict=True
): ):
key = None key = None
for i, (subkey, subary) in zip(container_indices, subarys): for i, (subkey, subary) in zip(container_indices, subarys, strict=True):
if key is None: if key is None:
key = subkey key = subkey
else: else:
...@@ -152,15 +175,15 @@ def _multimap_array_container_impl( ...@@ -152,15 +175,15 @@ def _multimap_array_container_impl(
new_args[i] = subary new_args[i] = subary
result.append((key, frec(*new_args))) # type: ignore[operator] result.append((key, frec(*new_args)))
return process_container(template_ary, result) # type: ignore[operator] return process_container(template_ary, result)
# }}} # }}}
# {{{ find all containers in the argument list # {{{ find all containers in the argument list
container_indices: List[int] = [] container_indices: list[int] = []
for i, arg in enumerate(args): for i, arg in enumerate(args):
if type(arg) is leaf_cls: if type(arg) is leaf_cls:
...@@ -195,7 +218,7 @@ def _multimap_array_container_impl( ...@@ -195,7 +218,7 @@ def _multimap_array_container_impl(
return f(*new_args) return f(*new_args)
update_wrapper(wrapper, f) update_wrapper(wrapper, f)
template_ary: ContainerT = args[container_indices[0]] template_ary: ArrayContainer = args[container_indices[0]]
return _map_array_container_impl( return _map_array_container_impl(
wrapper, template_ary, wrapper, template_ary,
leaf_cls=leaf_cls, recursive=recursive) leaf_cls=leaf_cls, recursive=recursive)
...@@ -216,9 +239,33 @@ def _multimap_array_container_impl( ...@@ -216,9 +239,33 @@ def _multimap_array_container_impl(
# {{{ array container traversal # {{{ array container traversal
def stringify_array_container_tree(ary: ArrayOrContainer) -> str:
"""
:returns: a string for an ASCII tree representation of the array container,
similar to `asciitree <https://github.com/mbr/asciitree>`__.
"""
def rec(lines: list[str], ary_: ArrayOrContainer, level: int) -> None:
try:
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
pass
else:
for key, subary in iterable:
key = f"{key} ({type(subary).__name__})"
indent = "" if level == 0 else f" | {' ' * 4 * (level - 1)}"
lines.append(f"{indent} +-- {key}")
rec(lines, subary, level + 1)
lines = [f"root ({type(ary).__name__})"]
rec(lines, ary, 0)
return "\n".join(lines)
def map_array_container( def map_array_container(
f: Callable[[Any], Any], f: Callable[[Any], Any],
ary: ArrayOrContainerT) -> ArrayOrContainerT: ary: ArrayOrContainer) -> ArrayOrContainer:
r"""Applies *f* to all components of an :class:`ArrayContainer`. r"""Applies *f* to all components of an :class:`ArrayContainer`.
Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but
...@@ -256,8 +303,8 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: ...@@ -256,8 +303,8 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
def rec_map_array_container( def rec_map_array_container(
f: Callable[[Any], Any], f: Callable[[Any], Any],
ary: ArrayOrContainerT, ary: ArrayOrContainer,
leaf_class: Optional[type] = None) -> ArrayOrContainerT: leaf_class: type | None = None) -> ArrayOrContainer:
r"""Applies *f* recursively to an :class:`ArrayContainer`. r"""Applies *f* recursively to an :class:`ArrayContainer`.
For a non-recursive version see :func:`map_array_container`. For a non-recursive version see :func:`map_array_container`.
...@@ -269,15 +316,15 @@ def rec_map_array_container( ...@@ -269,15 +316,15 @@ def rec_map_array_container(
def mapped_over_array_containers( def mapped_over_array_containers(
f: Optional[Callable[[Any], Any]] = None, f: Callable[[ArrayOrContainer], ArrayOrContainer] | None = None,
leaf_class: Optional[type] = None) -> Union[ leaf_class: type | None = None) -> (
Callable[[ArrayOrContainerT], ArrayOrContainerT], Callable[[ArrayOrContainer], ArrayOrContainer]
Callable[ | Callable[
[Callable[[Any], Any]], [Callable[[Any], Any]],
Callable[[ArrayOrContainerT], ArrayOrContainerT]]]: Callable[[ArrayOrContainer], ArrayOrContainer]]):
"""Decorator around :func:`rec_map_array_container`.""" """Decorator around :func:`rec_map_array_container`."""
def decorator(g: Callable[[Any], Any]) -> Callable[ def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[
[ArrayOrContainerT], ArrayOrContainerT]: [ArrayOrContainer], ArrayOrContainer]:
wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class) wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class)
update_wrapper(wrapper, g) update_wrapper(wrapper, g)
return wrapper return wrapper
...@@ -290,7 +337,7 @@ def mapped_over_array_containers( ...@@ -290,7 +337,7 @@ def mapped_over_array_containers(
def rec_multimap_array_container( def rec_multimap_array_container(
f: Callable[..., Any], f: Callable[..., Any],
*args: Any, *args: Any,
leaf_class: Optional[type] = None) -> Any: leaf_class: type | None = None) -> Any:
r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s. r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
For a non-recursive version see :func:`multimap_array_container`. For a non-recursive version see :func:`multimap_array_container`.
...@@ -303,10 +350,10 @@ def rec_multimap_array_container( ...@@ -303,10 +350,10 @@ def rec_multimap_array_container(
def multimapped_over_array_containers( def multimapped_over_array_containers(
f: Optional[Callable[..., Any]] = None, f: Callable[..., Any] | None = None,
leaf_class: Optional[type] = None) -> Union[ leaf_class: type | None = None) -> (
Callable[..., Any], Callable[..., Any]
Callable[[Callable[..., Any]], Callable[..., Any]]]: | Callable[[Callable[..., Any]], Callable[..., Any]]):
"""Decorator around :func:`rec_multimap_array_container`.""" """Decorator around :func:`rec_multimap_array_container`."""
def decorator(g: Callable[..., Any]) -> Callable[..., Any]: def decorator(g: Callable[..., Any]) -> Callable[..., Any]:
# can't use functools.partial, because its result is insufficiently # can't use functools.partial, because its result is insufficiently
...@@ -328,9 +375,9 @@ def multimapped_over_array_containers( ...@@ -328,9 +375,9 @@ def multimapped_over_array_containers(
def keyed_map_array_container( def keyed_map_array_container(
f: Callable[ f: Callable[
[Any, ArrayOrContainerT], [SerializationKey, ArrayOrContainer],
ArrayOrContainerT], ArrayOrContainer],
ary: ArrayOrContainerT) -> ArrayOrContainerT: ary: ArrayOrContainer) -> ArrayOrContainer:
r"""Applies *f* to all components of an :class:`ArrayContainer`. r"""Applies *f* to all components of an :class:`ArrayContainer`.
Works similarly to :func:`map_array_container`, but *f* also takes an Works similarly to :func:`map_array_container`, but *f* also takes an
...@@ -343,9 +390,9 @@ def keyed_map_array_container( ...@@ -343,9 +390,9 @@ def keyed_map_array_container(
""" """
try: try:
iterable = serialize_container(ary) iterable = serialize_container(ary)
except NotAnArrayContainerError: except NotAnArrayContainerError as err:
raise ValueError( raise ValueError(
f"Non-array container type has no key: {type(ary).__name__}") f"Non-array container type has no key: {type(ary).__name__}") from err
else: else:
return deserialize_container(ary, [ return deserialize_container(ary, [
(key, f(key, subary)) for key, subary in iterable (key, f(key, subary)) for key, subary in iterable
...@@ -353,8 +400,8 @@ def keyed_map_array_container( ...@@ -353,8 +400,8 @@ def keyed_map_array_container(
def rec_keyed_map_array_container( def rec_keyed_map_array_container(
f: Callable[[Tuple[Any, ...], ArrayT], ArrayT], f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
ary: ArrayOrContainerT) -> ArrayOrContainerT: ary: ArrayOrContainer) -> ArrayOrContainer:
""" """
Works similarly to :func:`rec_map_array_container`, except that *f* also Works similarly to :func:`rec_map_array_container`, except that *f* also
takes in a traversal path to the leaf array. The traversal path argument is takes in a traversal path to the leaf array. The traversal path argument is
...@@ -362,15 +409,15 @@ def rec_keyed_map_array_container( ...@@ -362,15 +409,15 @@ def rec_keyed_map_array_container(
the current array. the current array.
""" """
def rec(keys: Tuple[Union[str, int], ...], def rec(keys: tuple[SerializationKey, ...],
_ary: ArrayOrContainerT) -> ArrayOrContainerT: ary_: ArrayOrContainerT) -> ArrayOrContainerT:
try: try:
iterable = serialize_container(_ary) iterable = serialize_container(ary_)
except NotAnArrayContainerError: except NotAnArrayContainerError:
return f(keys, _ary) return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_)))
else: else:
return deserialize_container(_ary, [ return deserialize_container(ary_, [
(key, rec(keys + (key,), subary)) for key, subary in iterable (key, rec((*keys, key), subary)) for key, subary in iterable
]) ])
return rec((), ary) return rec((), ary)
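
# Sketch of how the key-path argument can be used, e.g. to record where each
# leaf array lives inside a nested container (assuming the usual package-level
# re-export; object arrays again serve as a dependency-free container type).
import numpy as np

from arraycontext import rec_keyed_map_array_container

nested = np.empty(2, dtype=object)
nested[0] = np.zeros(4)
nested[1] = np.ones(4)

paths = []


def remember_path(keys, leaf):
    paths.append(keys)   # tuple of serialization keys leading to this leaf
    return leaf          # leave the data unchanged


rec_keyed_map_array_container(remember_path, nested)
# paths now holds one key tuple per leaf; the exact key values depend on the
# container type's serialization.
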
...@@ -383,7 +430,7 @@ def rec_keyed_map_array_container( ...@@ -383,7 +430,7 @@ def rec_keyed_map_array_container(
def map_reduce_array_container( def map_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any], reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[[Any], Any], map_func: Callable[[Any], Any],
ary: ArrayOrContainerT) -> "DeviceArray": ary: ArrayOrContainerT) -> Array:
"""Perform a map-reduce over array containers. """Perform a map-reduce over array containers.
:param reduce_func: callable used to reduce over the components of *ary* :param reduce_func: callable used to reduce over the components of *ary*
...@@ -406,7 +453,7 @@ def map_reduce_array_container( ...@@ -406,7 +453,7 @@ def map_reduce_array_container(
def multimap_reduce_array_container( def multimap_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any], reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[..., Any], map_func: Callable[..., Any],
*args: Any) -> "DeviceArray": *args: Any) -> ArrayOrContainer:
r"""Perform a map-reduce over multiple array containers. r"""Perform a map-reduce over multiple array containers.
:param reduce_func: callable used to reduce over the components of any :param reduce_func: callable used to reduce over the components of any
...@@ -418,7 +465,9 @@ def multimap_reduce_array_container( ...@@ -418,7 +465,9 @@ def multimap_reduce_array_container(
""" """
# NOTE: this wrapper matches the signature of `deserialize_container` # NOTE: this wrapper matches the signature of `deserialize_container`
# to make plugging into `_multimap_array_container_impl` easier # to make plugging into `_multimap_array_container_impl` easier
def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any: def _reduce_wrapper(
ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]]
) -> Array:
return reduce_func([subary for _, subary in iterable]) return reduce_func([subary for _, subary in iterable])
return _multimap_array_container_impl( return _multimap_array_container_impl(
...@@ -429,8 +478,8 @@ def multimap_reduce_array_container( ...@@ -429,8 +478,8 @@ def multimap_reduce_array_container(
def rec_map_reduce_array_container( def rec_map_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any], reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[[Any], Any], map_func: Callable[[Any], Any],
ary: ArrayOrContainerT, ary: ArrayOrContainer,
leaf_class: Optional[type] = None) -> "DeviceArray": leaf_class: type | None = None) -> ArrayOrContainer:
"""Perform a map-reduce over array containers recursively. """Perform a map-reduce over array containers recursively.
:param reduce_func: callable used to reduce over the components of *ary* :param reduce_func: callable used to reduce over the components of *ary*
...@@ -468,14 +517,14 @@ def rec_map_reduce_array_container( ...@@ -468,14 +517,14 @@ def rec_map_reduce_array_container(
or any other such traversal. or any other such traversal.
""" """
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: def rec(ary_: ArrayOrContainerT) -> ArrayOrContainerT:
if type(_ary) is leaf_class: if type(ary_) is leaf_class:
return map_func(_ary) return map_func(ary_)
else: else:
try: try:
iterable = serialize_container(_ary) iterable = serialize_container(ary_)
except NotAnArrayContainerError: except NotAnArrayContainerError:
return map_func(_ary) return map_func(ary_)
else: else:
return reduce_func([ return reduce_func([
rec(subary) for _, subary in iterable rec(subary) for _, subary in iterable
...@@ -488,7 +537,7 @@ def rec_multimap_reduce_array_container( ...@@ -488,7 +537,7 @@ def rec_multimap_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any], reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[..., Any], map_func: Callable[..., Any],
*args: Any, *args: Any,
leaf_class: Optional[type] = None) -> "DeviceArray": leaf_class: type | None = None) -> ArrayOrContainer:
r"""Perform a map-reduce over multiple array containers recursively. r"""Perform a map-reduce over multiple array containers recursively.
:param reduce_func: callable used to reduce over the components of any :param reduce_func: callable used to reduce over the components of any
...@@ -506,7 +555,8 @@ def rec_multimap_reduce_array_container( ...@@ -506,7 +555,8 @@ def rec_multimap_reduce_array_container(
""" """
# NOTE: this wrapper matches the signature of `deserialize_container` # NOTE: this wrapper matches the signature of `deserialize_container`
# to make plugging into `_multimap_array_container_impl` easier # to make plugging into `_multimap_array_container_impl` easier
def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any: def _reduce_wrapper(
ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]]) -> Any:
return reduce_func([subary for _, subary in iterable]) return reduce_func([subary for _, subary in iterable])
return _multimap_array_container_impl( return _multimap_array_container_impl(
...@@ -518,10 +568,9 @@ def rec_multimap_reduce_array_container( ...@@ -518,10 +568,9 @@ def rec_multimap_reduce_array_container(
# {{{ freeze/thaw # {{{ freeze/thaw
@singledispatch
def freeze( def freeze(
ary: ArrayOrContainerT, ary: ArrayOrContainerT,
actx: Optional[ArrayContext] = None) -> ArrayOrContainerT: actx: ArrayContext | None = None) -> ArrayOrContainerT:
r"""Freezes recursively by going through all components of the r"""Freezes recursively by going through all components of the
:class:`ArrayContainer` *ary*. :class:`ArrayContainer` *ary*.
...@@ -532,23 +581,33 @@ def freeze( ...@@ -532,23 +581,33 @@ def freeze(
See :meth:`ArrayContext.thaw`. See :meth:`ArrayContext.thaw`.
""" """
try:
iterable = serialize_container(ary) if actx is None:
except NotAnArrayContainerError: warn("Calling freeze(ary) without specifying actx is deprecated, explicitly"
if actx is None: " call actx.freeze(ary) instead. This will stop working in 2023.",
raise TypeError( DeprecationWarning, stacklevel=2)
f"cannot freeze arrays of type {type(ary).__name__} "
"when actx is not supplied. Try calling actx.freeze " actx = get_container_context_recursively_opt(ary)
"directly or supplying an array context")
else:
return actx.freeze(ary)
else: else:
return deserialize_container(ary, [ warn("Calling freeze(ary, actx) is deprecated, call actx.freeze(ary)"
(key, freeze(subary, actx=actx)) for key, subary in iterable " instead. This will stop working in 2023.",
]) DeprecationWarning, stacklevel=2)
if __debug__:
rec_actx = get_container_context_recursively_opt(ary)
if (rec_actx is not None) and (rec_actx is not actx):
raise ValueError("Supplied array context does not agree with"
" the one obtained by traversing 'ary'.")
if actx is None:
raise TypeError(
f"cannot freeze arrays of type {type(ary).__name__} "
"when actx is not supplied. Try calling actx.freeze "
"directly or supplying an array context")
return actx.freeze(ary)
@singledispatch
def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
r"""Thaws recursively by going through all components of the r"""Thaws recursively by going through all components of the
:class:`ArrayContainer` *ary*. :class:`ArrayContainer` *ary*.
...@@ -569,14 +628,41 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: ...@@ -569,14 +628,41 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
in :mod:`meshmode`. This was necessary because in :mod:`meshmode`. This was necessary because
:func:`~functools.singledispatch` only dispatches on the first argument. :func:`~functools.singledispatch` only dispatches on the first argument.
""" """
warn("Calling thaw(ary, actx) is deprecated, call actx.thaw(ary) instead."
" This will stop working in 2023.",
DeprecationWarning, stacklevel=2)
if __debug__:
rec_actx = get_container_context_recursively_opt(ary)
if rec_actx is not None:
raise ValueError("cannot thaw a container that already has an array"
" context.")
return actx.thaw(ary)
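
# The spelling these deprecations point to: call the methods on the array
# context itself. Sketch using the PyOpenCL-based context (assumes pyopencl
# and a working CL device).
import numpy as np
import pyopencl as cl

from arraycontext import PyOpenCLArrayContext

actx = PyOpenCLArrayContext(cl.CommandQueue(cl.create_some_context()))

ary = actx.from_numpy(np.linspace(0, 1, 10))
frozen = actx.freeze(ary)   # detach from the array context
thawed = actx.thaw(frozen)  # re-attach before computing with it again
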
# }}}
# {{{ with_array_context
@singledispatch
def with_array_context(ary: ArrayOrContainerT,
actx: ArrayContext | None) -> ArrayOrContainerT:
"""
Recursively associates *actx* with all the components of *ary*.
Array container types may use :func:`functools.singledispatch` ``.register``
to register container-specific implementations. See `this issue
<https://github.com/inducer/arraycontext/issues/162>`__ for discussion of
the future of this functionality.
"""
try: try:
iterable = serialize_container(ary) iterable = serialize_container(ary)
except NotAnArrayContainerError: except NotAnArrayContainerError:
return actx.thaw(ary) return ary
else: else:
return deserialize_container(ary, [ return deserialize_container(ary, [(key, with_array_context(subary, actx))
(key, thaw(subary, actx)) for key, subary in iterable for key, subary in iterable])
])
# }}} # }}}
...@@ -584,8 +670,8 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: ...@@ -584,8 +670,8 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
# {{{ flatten / unflatten # {{{ flatten / unflatten
def flatten( def flatten(
ary: ArrayOrContainerT, actx: ArrayContext, *, ary: ArrayOrContainer, actx: ArrayContext, *,
leaf_class: Optional[type] = None, leaf_class: type | None = None,
) -> Any: ) -> Any:
"""Convert all arrays in the :class:`~arraycontext.ArrayContainer` """Convert all arrays in the :class:`~arraycontext.ArrayContainer`
into a single flat array of a type :attr:`arraycontext.ArrayContext.array_types`. into a single flat array of a type :attr:`arraycontext.ArrayContext.array_types`.
...@@ -607,32 +693,34 @@ def flatten( ...@@ -607,32 +693,34 @@ def flatten(
""" """
common_dtype = None common_dtype = None
def _flatten(subary: ArrayOrContainerT) -> List[Any]: def _flatten(subary: ArrayOrContainer) -> list[Array]:
nonlocal common_dtype nonlocal common_dtype
try: try:
iterable = serialize_container(subary) iterable = serialize_container(subary)
except NotAnArrayContainerError: except NotAnArrayContainerError:
subary_c = cast(Array, subary)
if common_dtype is None: if common_dtype is None:
common_dtype = subary.dtype common_dtype = subary_c.dtype
if subary.dtype != common_dtype: if subary_c.dtype != common_dtype:
raise ValueError("arrays in container have different dtypes: " raise ValueError("arrays in container have different dtypes: "
f"got {subary.dtype}, expected {common_dtype}") f"got {subary_c.dtype}, expected {common_dtype}") from None
try: try:
flat_subary = actx.np.ravel(subary, order="C") flat_subary = actx.np.ravel(subary_c, order="C")
except ValueError as exc: except ValueError as exc:
# NOTE: we can't do much if the array context fails to ravel, # NOTE: we can't do much if the array context fails to ravel,
# since it is the one responsible for the actual memory layout # since it is the one responsible for the actual memory layout
if hasattr(subary, "strides"): if hasattr(subary_c, "strides"):
strides_msg = f" and strides {subary.strides}" strides_msg = f" and strides {subary_c.strides}"
else: else:
strides_msg = "" strides_msg = ""
raise NotImplementedError( raise NotImplementedError(
f"'{type(actx).__name__}.np.ravel' failed to reshape " f"'{type(actx).__name__}.np.ravel' failed to reshape "
f"an array with shape {subary.shape}{strides_msg}. " f"an array with shape {subary_c.shape}{strides_msg}. "
"This functionality needs to be implemented by the " "This functionality needs to be implemented by the "
"array context.") from exc "array context.") from exc
...@@ -644,7 +732,7 @@ def flatten( ...@@ -644,7 +732,7 @@ def flatten(
return result return result
def _flatten_without_leaf_class(subary: ArrayOrContainerT) -> Any: def _flatten_without_leaf_class(subary: ArrayOrContainer) -> Any:
result = _flatten(subary) result = _flatten(subary)
if len(result) == 1: if len(result) == 1:
...@@ -652,7 +740,7 @@ def flatten( ...@@ -652,7 +740,7 @@ def flatten(
else: else:
return actx.np.concatenate(result) return actx.np.concatenate(result)
def _flatten_with_leaf_class(subary: ArrayOrContainerT) -> Any: def _flatten_with_leaf_class(subary: ArrayOrContainer) -> Any:
if type(subary) is leaf_class: if type(subary) is leaf_class:
return _flatten_without_leaf_class(subary) return _flatten_without_leaf_class(subary)
...@@ -673,7 +761,7 @@ def flatten( ...@@ -673,7 +761,7 @@ def flatten(
def unflatten( def unflatten(
template: ArrayOrContainerT, ary: Any, template: ArrayOrContainerT, ary: Array,
actx: ArrayContext, *, actx: ArrayContext, *,
strict: bool = True) -> ArrayOrContainerT: strict: bool = True) -> ArrayOrContainerT:
"""Unflatten an array *ary* produced by :func:`flatten` back into an """Unflatten an array *ary* produced by :func:`flatten` back into an
...@@ -692,46 +780,49 @@ def unflatten( ...@@ -692,46 +780,49 @@ def unflatten(
offset = 0 offset = 0
common_dtype = None common_dtype = None
def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT: def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
nonlocal offset, common_dtype nonlocal offset, common_dtype
try: try:
iterable = serialize_container(template_subary) iterable = serialize_container(template_subary)
except NotAnArrayContainerError: except NotAnArrayContainerError:
template_subary_c = cast(Array, template_subary)
# {{{ validate subary # {{{ validate subary
if (offset + template_subary.size) > ary.size: if (offset + template_subary_c.size) > ary.size:
raise ValueError("'template' and 'ary' sizes do not match: " raise ValueError("'template' and 'ary' sizes do not match: "
"'template' is too large") "'template' is too large") from None
if strict: if strict:
if template_subary.dtype != ary.dtype: if template_subary_c.dtype != ary.dtype:
raise ValueError("'template' dtype does not match 'ary': " raise ValueError("'template' dtype does not match 'ary': "
f"got {template_subary.dtype}, expected {ary.dtype}") f"got {template_subary_c.dtype}, expected {ary.dtype}"
) from None
else: else:
# NOTE: still require that *template* has a uniform dtype # NOTE: still require that *template* has a uniform dtype
if common_dtype is None: if common_dtype is None:
common_dtype = template_subary.dtype common_dtype = template_subary_c.dtype
else: else:
if common_dtype != template_subary.dtype: if common_dtype != template_subary_c.dtype:
raise ValueError("arrays in 'template' have different " raise ValueError("arrays in 'template' have different "
f"dtypes: got {template_subary.dtype}, but " f"dtypes: got {template_subary_c.dtype}, but "
f"expected {common_dtype}.") f"expected {common_dtype}.") from None
# }}} # }}}
# {{{ reshape # {{{ reshape
flat_subary = ary[offset:offset + template_subary.size] flat_subary = ary[offset:offset + template_subary_c.size]
try: try:
subary = actx.np.reshape(flat_subary, subary = actx.np.reshape(flat_subary,
template_subary.shape, order="C") template_subary_c.shape, order="C")
except ValueError as exc: except ValueError as exc:
# NOTE: we can't do much if the array context fails to reshape, # NOTE: we can't do much if the array context fails to reshape,
# since it is the one responsible for the actual memory layout # since it is the one responsible for the actual memory layout
raise NotImplementedError( raise NotImplementedError(
f"'{type(actx).__name__}.np.reshape' failed to reshape " f"'{type(actx).__name__}.np.reshape' failed to reshape "
f"the flat array into shape {template_subary.shape}. " f"the flat array into shape {template_subary_c.shape}. "
"This functionality needs to be implemented by the " "This functionality needs to be implemented by the "
"array context.") from exc "array context.") from exc
...@@ -739,27 +830,34 @@ def unflatten( ...@@ -739,27 +830,34 @@ def unflatten(
# {{{ check strides # {{{ check strides
if strict and hasattr(template_subary, "strides"): if strict and hasattr(template_subary_c, "strides"): # noqa: SIM102
if template_subary.strides != subary.strides: # Checking strides for 0 sized arrays is ill-defined
# since they cannot be indexed
if (
# Mypy has a point: nobody promised a .strides attribute.
template_subary_c.strides != subary.strides
and template_subary_c.size != 0
):
raise ValueError( raise ValueError(
# Mypy has a point: nobody promised a .strides attribute.
f"strides do not match template: got {subary.strides}, " f"strides do not match template: got {subary.strides}, "
f"expected {template_subary.strides}") f"expected {template_subary_c.strides}") from None
# }}} # }}}
offset += template_subary.size offset += template_subary_c.size
return subary return subary
else: else:
return deserialize_container(template_subary, [ return deserialize_container(template_subary, [
(key, _unflatten(isubary)) for key, isubary in iterable (key, _unflatten(isubary)) for key, isubary in iterable
]) ])
if not isinstance(ary, actx.array_types): if not isinstance(ary, actx.array_types):
raise TypeError("'ary' does not have a type supported by the provided " raise TypeError("'ary' does not have a type supported by the provided "
f"array context: got '{type(ary).__name__}', expected one of " f"array context: got '{type(ary).__name__}', expected one of "
f"{actx.array_types}") f"{actx.array_types}")
if ary.ndim != 1: if len(ary.shape) != 1:
raise ValueError( raise ValueError(
"only one dimensional arrays can be unflattened: " "only one dimensional arrays can be unflattened: "
f"'ary' has shape {ary.shape}") f"'ary' has shape {ary.shape}")
...@@ -769,7 +867,39 @@ def unflatten( ...@@ -769,7 +867,39 @@ def unflatten(
raise ValueError("'template' and 'ary' sizes do not match: " raise ValueError("'template' and 'ary' sizes do not match: "
"'ary' is too large") "'ary' is too large")
return result return cast(ArrayOrContainerT, result)
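
# Round-trip sketch for the two functions above, using an object-array
# container of device arrays (PyOpenCL-based context assumed, as in the
# freeze/thaw sketch).
import numpy as np
import pyopencl as cl

from arraycontext import PyOpenCLArrayContext, flatten, unflatten

actx = PyOpenCLArrayContext(cl.CommandQueue(cl.create_some_context()))

state = np.empty(2, dtype=object)
state[0] = actx.from_numpy(np.zeros(3))
state[1] = actx.from_numpy(np.ones(5))

flat = flatten(state, actx)              # one flat 8-entry device array
restored = unflatten(state, flat, actx)  # same nesting and shapes as `state`
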
def flat_size_and_dtype(
ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]:
"""
:returns: a tuple ``(size, dtype)`` that would be the length and
:class:`numpy.dtype` of the one-dimensional array returned by
:func:`flatten`.
"""
common_dtype = None
def _flat_size(subary: ArrayOrContainer) -> int:
nonlocal common_dtype
try:
iterable = serialize_container(subary)
except NotAnArrayContainerError:
subary_c = cast(Array, subary)
if common_dtype is None:
common_dtype = subary_c.dtype
if subary_c.dtype != common_dtype:
raise ValueError("arrays in container have different dtypes: "
f"got {subary_c.dtype}, expected {common_dtype}") from None
return subary_c.size
else:
return sum(_flat_size(isubary) for _, isubary in iterable)
size = _flat_size(ary)
return size, common_dtype
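
# For the `state` container from the flatten/unflatten sketch above, this
# reports the flat metadata without building the flat array:
flat_size, flat_dtype = flat_size_and_dtype(state)   # 8, dtype('float64')
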
# }}} # }}}
...@@ -777,38 +907,31 @@ def unflatten( ...@@ -777,38 +907,31 @@ def unflatten(
# {{{ numpy conversion # {{{ numpy conversion
def from_numpy( def from_numpy(
ary: Union[np.ndarray, _ScalarLike], ary: np.ndarray | ScalarLike,
actx: ArrayContext) -> ArrayOrContainerT: actx: ArrayContext) -> ArrayOrContainerOrScalar:
"""Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer`
to the base array type of :class:`~arraycontext.ArrayContext`. to the base array type of :class:`~arraycontext.ArrayContext`.
The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`.
""" """
def _from_numpy_with_check(subary: Union[np.ndarray, _ScalarLike]) \ warn("Calling from_numpy(ary, actx) is deprecated, call actx.from_numpy(ary)"
-> ArrayOrContainerT: " instead. This will stop working in 2023.",
if isinstance(subary, np.ndarray) or np.isscalar(subary): DeprecationWarning, stacklevel=2)
return actx.from_numpy(subary)
else:
raise TypeError(f"array is not an ndarray: '{type(subary).__name__}'")
return rec_map_array_container(_from_numpy_with_check, ary) return actx.from_numpy(ary)
def to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: def to_numpy(ary: ArrayOrContainer, actx: ArrayContext) -> ArrayOrContainer:
"""Convert all arrays in the :class:`~arraycontext.ArrayContainer` to """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to
:mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*. :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*.
The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`. The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`.
""" """
def _to_numpy_with_check(subary: Any) -> Any: warn("Calling to_numpy(ary, actx) is deprecated, call actx.to_numpy(ary)"
if isinstance(subary, actx.array_types) or np.isscalar(subary): " instead. This will stop working in 2023.",
return actx.to_numpy(subary) DeprecationWarning, stacklevel=2)
else:
raise TypeError(
f"array of type '{type(subary).__name__}' not in "
f"supported types {actx.array_types}")
return rec_map_array_container(_to_numpy_with_check, ary) return actx.to_numpy(ary)
# }}} # }}}
...@@ -823,8 +946,7 @@ def outer(a: Any, b: Any) -> Any: ...@@ -823,8 +946,7 @@ def outer(a: Any, b: Any) -> Any:
Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional
object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer` object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer`
always returns a matrix). Here the definition of "scalar" includes always returns a matrix). Here the definition of "scalar" includes
all non-array-container types and any scalar-like array container types all non-array-container types and any scalar-like array container types.
(including non-object numpy arrays).
If *a* and *b* are both array containers, the result will have the same type If *a* and *b* are both array containers, the result will have the same type
as *a*. If both are array containers and neither is an object array, they must as *a*. If both are array containers and neither is an object array, they must
...@@ -840,17 +962,24 @@ def outer(a: Any, b: Any) -> Any: ...@@ -840,17 +962,24 @@ def outer(a: Any, b: Any) -> Any:
return ( return (
not isinstance(x, np.ndarray) not isinstance(x, np.ndarray)
# This condition is whether "ndarrays should broadcast inside x". # This condition is whether "ndarrays should broadcast inside x".
and np.ndarray not in x.__class__._outer_bcast_types) and NumpyObjectArray not in x.__class__._outer_bcast_types)
a_is_ndarray = isinstance(a, np.ndarray)
b_is_ndarray = isinstance(b, np.ndarray)
if a_is_ndarray and a.dtype != object:
raise TypeError("passing a non-object numpy array is not allowed")
if b_is_ndarray and b.dtype != object:
raise TypeError("passing a non-object numpy array is not allowed")
if treat_as_scalar(a) or treat_as_scalar(b): if treat_as_scalar(a) or treat_as_scalar(b):
return a*b return a*b
# After this point, "isinstance(o, ndarray)" means o is an object array. elif a_is_ndarray and b_is_ndarray:
elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
return np.outer(a, b) return np.outer(a, b)
elif isinstance(a, np.ndarray) or isinstance(b, np.ndarray): elif a_is_ndarray or b_is_ndarray:
return map_array_container(lambda x: outer(x, b), a) return map_array_container(lambda x: outer(x, b), a)
else: else:
if type(a) != type(b): if type(a) is not type(b):
raise TypeError( raise TypeError(
"both arguments must have the same type if they are both " "both arguments must have the same type if they are both "
"non-object-array array containers.") "non-object-array array containers.")
......
# mypy: disallow-untyped-defs
""" """
.. _freeze-thaw: .. _freeze-thaw:
...@@ -39,7 +41,7 @@ Here are some rules of thumb to use when dealing with thawing and freezing: ...@@ -39,7 +41,7 @@ Here are some rules of thumb to use when dealing with thawing and freezing:
- Note that array contexts need not necessarily be passed as a separate - Note that array contexts need not necessarily be passed as a separate
argument. Passing thawed data as an argument to a function suffices argument. Passing thawed data as an argument to a function suffices
to supply an array context. The array context can be extracted from to supply an array context. The array context can be extracted from
a thawed argument using, e.g., :func:`~arraycontext.get_container_context` a thawed argument using, e.g., :func:`~arraycontext.get_container_context_opt`
or :func:`~arraycontext.get_container_context_recursively`. or :func:`~arraycontext.get_container_context_recursively`.
What does this mean concretely? What does this mean concretely?
...@@ -70,26 +72,75 @@ actual array contexts: ...@@ -70,26 +72,75 @@ actual array contexts:
an array expression that has been built up by the user an array expression that has been built up by the user
(using, e.g. :func:`pytato.generate_loopy`). (using, e.g. :func:`pytato.generate_loopy`).
The interface of an array context
---------------------------------
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
.. class:: DeviceArray The :class:`ArrayContext` Interface
-----------------------------------
A (type alias for an) array type supported by the :class:`ArrayContext` .. autoclass:: ArrayContext
meant to aid in typing annotations. For a explicit list of supported types
see :attr:`ArrayContext.array_types`.
.. class:: DeviceScalar .. autofunction:: tag_axes
A (type alias for a) scalar type supported by the :class:`ArrayContext` Types and Type Variables for Arrays and Containers
meant to aid in typing annotations, e.g. for reductions. In :mod:`numpy` --------------------------------------------------
terminology, this is just an array with a shape of ``()``.
.. autoclass:: ArrayContext .. autoclass:: Array
.. autodata:: ArrayT
A type variable with a lower bound of :class:`Array`.
.. autodata:: ScalarLike
A type annotation for scalar types commonly usable with arrays.
See also :class:`ArrayContainer` and :class:`ArrayOrContainerT`.
.. autodata:: ArrayOrContainer
.. autodata:: ArrayOrContainerT
A type variable with a bound of :class:`ArrayOrContainer`.
.. autodata:: ArrayOrArithContainer
.. autodata:: ArrayOrArithContainerT
A type variable with a bound of :class:`ArrayOrArithContainer`.
.. autodata:: ArrayOrArithContainerOrScalar
.. autodata:: ArrayOrArithContainerOrScalarT
A type variable with a bound of :class:`ArrayOrArithContainerOrScalar`.
.. autodata:: ArrayOrContainerOrScalar
.. autodata:: ArrayOrContainerOrScalarT
A type variable with a bound of :class:`ArrayOrContainerOrScalar`.
.. currentmodule:: arraycontext.context
Canonical locations for type annotations
----------------------------------------
.. class:: ArrayT
:canonical: arraycontext.ArrayT
.. class:: ArrayOrContainerT
:canonical: arraycontext.ArrayOrContainerT
.. class:: ArrayOrContainerOrScalarT
:canonical: arraycontext.ArrayOrContainerOrScalarT
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
...@@ -115,17 +166,119 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -115,17 +166,119 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from typing import Sequence, Union, Callable, Any, Tuple from abc import ABC, abstractmethod
from abc import ABC, abstractmethod, abstractproperty from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload
from warnings import warn
import numpy as np import numpy as np
from typing_extensions import Self
from pytools import memoize_method from pytools import memoize_method
from pytools.tag import Tag from pytools.tag import ToTagSetConvertible
if TYPE_CHECKING:
import loopy
from arraycontext.container import ArithArrayContainer, ArrayContainer
DeviceArray = Any # {{{ typing
DeviceScalar = Any
_ScalarLike = Union[int, float, complex, np.generic] class Array(Protocol):
"""A :class:`~typing.Protocol` for the array type supported by
:class:`ArrayContext`.
This is meant to aid in typing annotations. For an explicit list of
supported types see :attr:`ArrayContext.array_types`.
.. attribute:: shape
.. attribute:: size
.. attribute:: dtype
.. attribute:: __getitem__
In addition, arrays are expected to support basic arithmetic.
"""
@property
def shape(self) -> tuple[int, ...]:
...
@property
def size(self) -> int:
...
@property
def dtype(self) -> np.dtype[Any]:
...
# Covering all the possible index variations is hard and (kind of) futile.
# If you'd like to see how, try changing the Any to
# AxisIndex = slice | int | "Array"
# Index = AxisIndex | tuple[AxisIndex]
def __getitem__(self, index: Any) -> Array:
...
# some basic arithmetic that's supposed to work
def __neg__(self) -> Self: ...
def __abs__(self) -> Self: ...
def __add__(self, other: Self | ScalarLike) -> Self: ...
def __radd__(self, other: Self | ScalarLike) -> Self: ...
def __sub__(self, other: Self | ScalarLike) -> Self: ...
def __rsub__(self, other: Self | ScalarLike) -> Self: ...
def __mul__(self, other: Self | ScalarLike) -> Self: ...
def __rmul__(self, other: Self | ScalarLike) -> Self: ...
def __truediv__(self, other: Self | ScalarLike) -> Self: ...
def __rtruediv__(self, other: Self | ScalarLike) -> Self: ...
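For instance, a helper written purely against this protocol might look like the following sketch (``rescale`` is a hypothetical user function, not part of the library)::

    def rescale(ary: Array, factor: float) -> Array:
        # relies only on protocol members (dtype and arithmetic)
        assert ary.dtype.kind in "fc"
        return factor * ary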
ScalarLike: TypeAlias = int | float | complex | np.generic
# deprecated, use ScalarLike instead
Scalar = ScalarLike
ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike)
# NOTE: I'm kind of not sure about the *Tc versions of these type variables.
# mypy seems better at understanding arithmetic performed on the *Tc versions
# than the *T versions, whereas pyright doesn't seem to care.
#
# This issue seems to be part of it:
# https://github.com/python/mypy/issues/18203
# but there is likely other stuff lurking.
#
# For now, they're purposefully not in the main arraycontext.* name space.
ArrayT = TypeVar("ArrayT", bound=Array)
ArrayOrScalar: TypeAlias = "Array | ScalarLike"
ArrayOrContainer: TypeAlias = "Array | ArrayContainer"
ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer"
ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer)
ArrayOrContainerTc = TypeVar("ArrayOrContainerTc",
Array, "ArrayContainer", "ArithArrayContainer")
ArrayOrArithContainerT = TypeVar("ArrayOrArithContainerT", bound=ArrayOrArithContainer)
ArrayOrArithContainerTc = TypeVar("ArrayOrArithContainerTc",
Array, "ArithArrayContainer")
ArrayOrContainerOrScalar: TypeAlias = "Array | ArrayContainer | ScalarLike"
ArrayOrArithContainerOrScalar: TypeAlias = "Array | ArithArrayContainer | ScalarLike"
ArrayOrContainerOrScalarT = TypeVar(
"ArrayOrContainerOrScalarT",
bound=ArrayOrContainerOrScalar)
ArrayOrArithContainerOrScalarT = TypeVar(
"ArrayOrArithContainerOrScalarT",
bound=ArrayOrArithContainerOrScalar)
ArrayOrContainerOrScalarTc = TypeVar(
"ArrayOrContainerOrScalarTc",
ScalarLike, Array, "ArrayContainer", "ArithArrayContainer")
ArrayOrArithContainerOrScalarTc = TypeVar(
"ArrayOrArithContainerOrScalarTc",
ScalarLike, Array, "ArithArrayContainer")
ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike")
NumpyOrContainerOrScalar = Union[np.ndarray, "ArrayContainer", ScalarLike]
# }}}
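These type variables exist so that container-generic code can express "same type in, same type out". A small sketch (the function itself is hypothetical)::

    def double_it(ary: ArrayOrArithContainerT) -> ArrayOrArithContainerT:
        # the bound guarantees arithmetic support; the TypeVar preserves
        # the concrete array/container type for the caller
        return 2 * ary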
# {{{ ArrayContext # {{{ ArrayContext
@@ -140,10 +293,6 @@ class ArrayContext(ABC):
.. versionadded:: 2020.2
.. automethod:: from_numpy
.. automethod:: to_numpy
.. automethod:: call_loopy
@@ -151,71 +300,102 @@ class ArrayContext(ABC):
.. attribute:: np
Provides access to a namespace that serves as a work-alike to
:mod:`numpy`. The actual level of functionality provided is up to the
individual array context implementation, however the functions and
objects available under this namespace must not behave differently
from :mod:`numpy`.
As a baseline, special functions available through :mod:`loopy`
(e.g. ``sin``, ``exp``) are accessible through this interface.
A full list of implemented functionality is given in
:ref:`numpy-coverage`.
Callables accessible through this namespace vectorize over object
arrays, including :class:`arraycontext.ArrayContainer`\ s.
.. attribute:: array_types
A :class:`tuple` of types that are the valid array classes the
context can operate on. However, it is not necessary that *all* the
:class:`ArrayContext`\ 's operations are legal for the types in
*array_types*. Note that this tuple is *only* intended for use
with :func:`isinstance`. Other uses are not allowed. This allows
for 'types' with overridden :meth:`type.__instancecheck__`.
.. automethod:: freeze
.. automethod:: thaw
.. automethod:: freeze_thaw
.. automethod:: tag
.. automethod:: tag_axis
.. automethod:: compile
"""
array_types: tuple[type, ...] = ()
def __init__(self) -> None:
self.np = self._get_fake_numpy_namespace()
@abstractmethod
def _get_fake_numpy_namespace(self) -> Any:
...
def __hash__(self) -> int:
raise TypeError(f"unhashable type: '{type(self).__name__}'")
def zeros(self,
shape: int | tuple[int, ...],
dtype: np.dtype[Any]) -> Array:
warn(f"{type(self).__name__}.zeros is deprecated and will stop "
"working in 2025. Use actx.np.zeros instead.",
DeprecationWarning, stacklevel=2)
return self.np.zeros(shape, dtype)
@overload
def from_numpy(self, array: np.ndarray) -> Array:
...
@overload
def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...
@abstractmethod
def from_numpy(self,
array: NumpyOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
r"""
:returns: the :class:`numpy.ndarray` *array* converted to the
array context's array type. The returned array will be
:meth:`thaw`\ ed. When working with array containers, each leaf
must be an :class:`~numpy.ndarray` or scalar, which is then converted
to the context's array type, leaving the container structure
intact.
"""
@overload
def to_numpy(self, array: Array) -> np.ndarray:
...
@overload
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...
@abstractmethod
def to_numpy(self,
array: ArrayOrContainerOrScalar
) -> NumpyOrContainerOrScalar:
r"""
:returns: an :class:`numpy.ndarray` for each array recognized by the
context. The input *array* must be :meth:`thaw`\ ed.
When working with array containers, each leaf must be one of
the context's array types or a scalar, which is then converted to
an :class:`~numpy.ndarray`, leaving the container structure intact.
"""
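A round trip through these two methods might look like the following sketch (``actx`` is assumed to be a concrete array context)::

    import numpy as np

    host = np.linspace(0.0, 1.0, 32)
    dev = actx.from_numpy(host)        # thawed, context-typed array
    host_again = actx.to_numpy(dev)    # back to numpy
    assert np.allclose(host, host_again)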
@abstractmethod
def call_loopy(self,
t_unit: loopy.TranslationUnit,
**kwargs: Any) -> dict[str, Array]:
"""Execute the :mod:`loopy` program *program* on the arguments """Execute the :mod:`loopy` program *program* on the arguments
*kwargs*. *kwargs*.
...@@ -228,7 +408,7 @@ class ArrayContext(ABC): ...@@ -228,7 +408,7 @@ class ArrayContext(ABC):
""" """
@abstractmethod @abstractmethod
def freeze(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
"""Return a version of the context-defined array *array* that is
'frozen', i.e. suitable for long-term storage and reuse. Frozen arrays
do not support arithmetic. For example, in the context of
@@ -239,12 +419,10 @@ class ArrayContext(ABC):
Freezing makes the array independent of this :class:`ArrayContext`;
it is permitted to :meth:`thaw` it in a different one, as long as that
context understands the array format.
"""
@abstractmethod
def thaw(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
"""Take a 'frozen' array and return a new array representing the data in
*array* that is able to perform arithmetic and other operations, using
the execution resources of this context. In the context of
@@ -254,39 +432,65 @@ class ArrayContext(ABC):
the data in *array*.
The returned array may not be used with other contexts while thawed.
"""
def freeze_thaw(
self, array: ArrayOrContainerOrScalarT
) -> ArrayOrContainerOrScalarT:
r"""Evaluate an input array or container to "frozen" data and return a new
"thawed" array or container representing the evaluation result that is
ready for use. This is a shortcut for calling :meth:`freeze` and
:meth:`thaw`.
This method can be useful in array contexts backed by, e.g.
:mod:`pytato`, to force the evaluation of a built-up array expression
(and thereby avoid reevaluations for expressions built on the array).
"""
return self.thaw(self.freeze(array))
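In a lazy (e.g. :mod:`pytato`-based) context, this might be used as in the following sketch (``actx`` and a thawed array ``x`` are assumed to exist)::

    expr = actx.np.sin(x) + actx.np.cos(x)   # built up as an expression
    evaluated = actx.freeze_thaw(expr)       # force evaluation once
    result = actx.np.sum(evaluated * evaluated)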
@abstractmethod
def tag(self,
tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
"""If the array type used by the array context is capable of capturing
metadata, return a version of *array* with the *tags* applied. *array*
itself is not modified. When working with array containers, the
tags are applied to each leaf of the container.
See :ref:`metadata` as well as application-specific metadata types.
.. versionadded:: 2021.2
"""
@abstractmethod
def tag_axis(self,
iaxis: int, tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
"""If the array type used by the array context is capable of capturing
metadata, return a version of *array* in which axis number *iaxis* has
the *tags* applied. *array* itself is not modified. When working with
array containers, the tags are applied to each leaf of the container.
See :ref:`metadata` as well as application-specific metadata types.
.. versionadded:: 2021.2
"""
@memoize_method
def _get_einsum_prg(self,
spec: str, arg_names: tuple[str, ...],
tagged: ToTagSetConvertible) -> loopy.TranslationUnit:
import loopy as lp
from loopy.version import MOST_RECENT_LANGUAGE_VERSION
from .loopy import _DEFAULT_LOOPY_OPTIONS
return lp.make_einsum(
spec,
arg_names,
options=_DEFAULT_LOOPY_OPTIONS,
lang_version=MOST_RECENT_LANGUAGE_VERSION,
tags=tagged,
default_order=lp.auto,
default_offset=lp.auto,
)
@@ -303,7 +507,10 @@ class ArrayContext(ABC):
# That's why einsum's interface here needs to be cluttered with
# metadata, and that's why it can't live under .np.
# [1] https://github.com/inducer/meshmode/issues/177
def einsum(self,
spec: str, *args: Array,
arg_names: tuple[str, ...] | None = None,
tagged: ToTagSetConvertible = ()) -> Array:
"""Computes the result of Einstein summation following the
convention in :func:`numpy.einsum`.
@@ -323,15 +530,16 @@ class ArrayContext(ABC):
:return: the output of the einsum :mod:`loopy` program
"""
if arg_names is None:
arg_names = tuple(f"arg{i}" for i in range(len(args)))
prg = self._get_einsum_prg(spec, arg_names, tagged)
out_ary = self.call_loopy(
prg, **{arg_names[i]: arg for i, arg in enumerate(args)}
)["out"]
return self.tag(tagged, out_ary)
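For example, a matrix-vector product could be spelled as in this sketch (``actx``, ``mat`` and ``vec`` are assumed to be an array context and two thawed arrays)::

    result = actx.einsum("ij,j->i", mat, vec, arg_names=("mat", "vec"))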
@abstractmethod
def clone(self) -> Self:
"""If possible, return a version of *self* that is semantically
equivalent (i.e. implements all array operations in the same way)
but is a separate object. May return *self* if that is not possible.
@@ -370,23 +578,52 @@ class ArrayContext(ABC):
return f
# undocumented for now
@property
@abstractmethod
def permits_inplace_modification(self) -> bool:
"""
*True* if the arrays allow in-place modifications.
"""
# undocumented for now
@property
@abstractmethod
def supports_nonscalar_broadcasting(self) -> bool:
"""
*True* if the arrays support non-scalar broadcasting.
"""
# undocumented for now
@property
@abstractmethod
def permits_advanced_indexing(self) -> bool:
"""
*True* if the arrays support :mod:`numpy`'s advanced indexing semantics.
"""
# }}}
# {{{ tagging helpers
def tag_axes(
actx: ArrayContext,
dim_to_tags: Mapping[int, ToTagSetConvertible],
ary: ArrayT) -> ArrayT:
"""
Return a copy of *ary* with the axes in *dim_to_tags* tagged with their
corresponding tags. Equivalent to repeated application of
:meth:`ArrayContext.tag_axis`.
"""
for iaxis, tags in dim_to_tags.items():
ary = actx.tag_axis(iaxis, tags, ary)
return ary
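As an illustration, axis metadata could be attached as in the following sketch, which uses a hypothetical application-defined tag class (any :class:`pytools.tag.Tag` subclass would do)::

    from pytools.tag import Tag

    class LayoutHintTag(Tag):   # hypothetical tag type
        pass

    ary = tag_axes(actx, {0: LayoutHintTag(), 1: LayoutHintTag()}, ary)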
# }}}
class UntransformedCodeWarning(UserWarning):
pass
# vim: foldmethod=marker
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
""" """
...@@ -23,59 +26,25 @@ THE SOFTWARE. ...@@ -23,59 +26,25 @@ THE SOFTWARE.
""" """
import operator
from abc import ABC, abstractmethod
from typing import Any
import numpy as np
# {{{ _get_scalar_func_loopy_program
def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
@memoize_in(actx, _get_scalar_func_loopy_program)
def get(c_name, nargs, naxes):
from pymbolic import var
var_names = ["i%d" % i for i in range(naxes)]
size_names = ["n%d" % i for i in range(naxes)]
subscript = tuple(var(vname) for vname in var_names)
from islpy import make_zero_and_vars
v = make_zero_and_vars(var_names, params=size_names)
domain = v[0].domain()
for vname, sname in zip(var_names, size_names):
domain = domain & v[0].le_set(v[vname]) & v[vname].lt_set(v[sname])
domain_bset, = domain.get_basic_sets()
import loopy as lp
from .loopy import make_loopy_program
from arraycontext.transform_metadata import ElementwiseMapKernelTag
return make_loopy_program(
[domain_bset],
[
lp.Assignment(
var("out")[subscript],
var(c_name)(*[
var("inp%d" % i)[subscript] for i in range(nargs)]))
],
name="actx_special_%s" % c_name,
tags=(ElementwiseMapKernelTag(),))
return get(c_name, nargs, naxes)
# }}}
from arraycontext.container import NotAnArrayContainerError, serialize_container
from arraycontext.container.traversal import rec_map_array_container
# {{{ BaseFakeNumpyNamespace
class BaseFakeNumpyNamespace(ABC):
def __init__(self, array_context):
self._array_context = array_context
self.linalg = self._get_fake_numpy_linalg_namespace()
def _get_fake_numpy_linalg_namespace(self):
return BaseFakeNumpyLinalgNamespace(self._array_context)
_numpy_math_functions = frozenset({
# https://numpy.org/doc/stable/reference/routines.math.html
@@ -124,77 +93,20 @@ class BaseFakeNumpyNamespace:
# Miscellaneous
"convolve", "clip", "sqrt", "cbrt", "square", "absolute", "abs", "fabs",
"sign", "heaviside", "maximum", "fmax", "nan_to_num", "isnan", "minimum",
"fmin",
# FIXME:
# "interp",
})
@abstractmethod
def zeros(self, shape, dtype):
...
_numpy_to_c_arc_functions = {
"arcsin": "asin",
"arccos": "acos",
"arctan": "atan",
"arctan2": "atan2",
"arcsinh": "asinh",
"arccosh": "acosh",
"arctanh": "atanh",
}
_c_to_numpy_arc_functions = {c_name: numpy_name
for numpy_name, c_name in _numpy_to_c_arc_functions.items()}
def __getattr__(self, name):
def loopy_implemented_elwise_func(*args):
if all(np.isscalar(ary) for ary in args):
return getattr(
np, self._c_to_numpy_arc_functions.get(name, name)
)(*args)
actx = self._array_context
prg = _get_scalar_func_loopy_program(actx,
c_name, nargs=len(args), naxes=len(args[0].shape))
outputs = actx.call_loopy(prg,
**{"inp%d" % i: arg for i, arg in enumerate(args)})
return outputs["out"]
if name in self._c_to_numpy_arc_functions:
from warnings import warn
warn(f"'{name}' in ArrayContext.np is deprecated. "
f"Use '{self._c_to_numpy_arc_functions[name]}' as in numpy. "
"The old name will stop working in 2021.",
DeprecationWarning, stacklevel=3)
# normalize to C names anyway
c_name = self._numpy_to_c_arc_functions.get(name, name)
# limit which functions we try to hand off to loopy
if (name in self._numpy_math_functions
or name in self._c_to_numpy_arc_functions):
return multimapped_over_array_containers(loopy_implemented_elwise_func)
else:
raise AttributeError(name)
def _new_like(self, ary, alloc_like):
if np.isscalar(ary):
# NOTE: `np.zeros_like(x)` returns `array(x, shape=())`, which
# is best implemented by concrete array contexts, if at all
raise NotImplementedError("operation not implemented for scalars")
if isinstance(ary, np.ndarray) and ary.dtype.char == "O":
# NOTE: we don't want to match numpy semantics on object arrays,
# e.g. `np.zeros_like(x)` returns `array([0, 0, ...], dtype=object)`
# FIXME: what about object arrays nested in an ArrayContainer?
raise NotImplementedError("operation not implemented for object arrays")
return rec_map_array_container(alloc_like, ary)
def empty_like(self, ary):
return self._new_like(ary, self._array_context.empty_like)
@abstractmethod
def zeros_like(self, ary):
...
def conjugate(self, x):
# NOTE: conjugate distributes over object arrays, but it looks for a
@@ -204,14 +116,97 @@ class BaseFakeNumpyNamespace:
conj = conjugate
# {{{ linspace
# based on
# https://github.com/numpy/numpy/blob/v1.25.0/numpy/core/function_base.py#L24-L182
def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis=0):
num = operator.index(num)
if num < 0:
raise ValueError(f"Number of samples, {num}, must be non-negative.")
div = (num - 1) if endpoint else num
# Convert float/complex array scalars to float, gh-3504
# and make sure one can use variables that have an __array_interface__,
# gh-6634
if isinstance(start, self._array_context.array_types):
raise NotImplementedError("start as an actx array")
if isinstance(stop, self._array_context.array_types):
raise NotImplementedError("stop as an actx array")
start = np.array(start) * 1.0
stop = np.array(stop) * 1.0
dt = np.result_type(start, stop, float(num))
if dtype is None:
dtype = dt
integer_dtype = False
else:
integer_dtype = np.issubdtype(dtype, np.integer)
delta = stop - start
y = self.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim)
if div > 0:
step = delta / div
# any_step_zero = _nx.asanyarray(step == 0).any()
any_step_zero = self._array_context.to_numpy(step == 0).any()
if any_step_zero:
delta_actx = self._array_context.from_numpy(delta)
# Special handling for denormal numbers, gh-5437
y = y / div
y = y * delta_actx
else:
step_actx = self._array_context.from_numpy(step)
y = y * step_actx
else:
delta_actx = self._array_context.from_numpy(delta)
# sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0)
# have an undefined step
step = np.nan
# Multiply with delta to allow possible override of output class.
y = y * delta_actx
y += start
# FIXME reenable, without in-place ops
# if endpoint and num > 1:
# y[-1, ...] = stop
if axis != 0:
# y = _nx.moveaxis(y, 0, axis)
raise NotImplementedError("axis != 0")
if integer_dtype:
y = self.floor(y) # pylint: disable=no-member
# FIXME: Use astype
# https://github.com/inducer/pytato/issues/456
if retstep:
return y, step
# return y.astype(dtype), step
else:
return y
# return y.astype(dtype)
# }}}
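Usage then mirrors :func:`numpy.linspace`; a sketch, assuming a concrete context whose namespace implements ``arange``::

    xs = actx.np.linspace(0.0, 1.0, 11)                      # 11 samples from 0.0 towards 1.0
    ys, step = actx.np.linspace(0.0, 1.0, 11, retstep=True)  # step is a numpy scalar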
def arange(self, *args: Any, **kwargs: Any):
raise NotImplementedError
# }}} # }}}
# {{{ BaseFakeNumpyLinalgNamespace # {{{ BaseFakeNumpyLinalgNamespace
def _reduce_norm(actx, arys, ord): def _reduce_norm(actx, arys, ord):
from numbers import Number
from functools import reduce from functools import reduce
from numbers import Number
if ord is None: if ord is None:
ord = 2 ord = 2
...@@ -284,6 +279,7 @@ class BaseFakeNumpyLinalgNamespace: ...@@ -284,6 +279,7 @@ class BaseFakeNumpyLinalgNamespace:
return actx.np.sum(abs(ary)**ord)**(1/ord) return actx.np.sum(abs(ary)**ord)**(1/ord)
else: else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}") raise NotImplementedError(f"unsupported value of 'ord': {ord}")
# }}} # }}}
......
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
""" """
......
"""
.. currentmodule:: arraycontext
.. autoclass:: EagerJAXArrayContext
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from collections.abc import Callable
import numpy as np
from pytools.tag import ToTagSetConvertible
from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
class EagerJAXArrayContext(ArrayContext):
"""
An :class:`ArrayContext` that uses
:class:`jax.Array` instances for its base array
class and performs all array operations eagerly. See
:class:`~arraycontext.PytatoJAXArrayContext` for a lazier version.
.. note::
JAX stores a global configuration state in :data:`jax.config`. Callers
are expected to maintain that state; the setting most relevant to
scientific computing workloads is ``jax_enable_x64``.
"""
def __init__(self) -> None:
super().__init__()
import jax.numpy as jnp
self.array_types = (jnp.ndarray, )
def _get_fake_numpy_namespace(self):
from .fake_numpy import EagerJAXFakeNumpyNamespace
return EagerJAXFakeNumpyNamespace(self)
def _rec_map_container(
self, func: Callable[[Array], Array], array: ArrayOrContainer,
allowed_types: tuple[type, ...] | None = None, *,
default_scalar: ScalarLike | None = None,
strict: bool = False) -> ArrayOrContainer:
if allowed_types is None:
allowed_types = self.array_types
def _wrapper(ary):
if isinstance(ary, allowed_types):
return func(ary)
elif np.isscalar(ary):
if default_scalar is None:
return ary
else:
return np.array(ary).dtype.type(default_scalar)
else:
raise TypeError(
f"{type(self).__name__}.{func.__name__[1:]} invoked with "
f"an unsupported array type: got '{type(ary).__name__}', "
f"but expected one of {allowed_types}")
return rec_map_array_container(_wrapper, array)
# {{{ ArrayContext interface
def from_numpy(self, array):
def _from_numpy(ary):
import jax
return jax.device_put(ary)
return with_array_context(
self._rec_map_container(_from_numpy, array, allowed_types=(np.ndarray,)),
actx=self)
def to_numpy(self, array):
def _to_numpy(ary):
import jax
return jax.device_get(ary)
return with_array_context(
self._rec_map_container(_to_numpy, array),
actx=None)
def freeze(self, array):
def _freeze(ary):
return ary.block_until_ready()
return with_array_context(self._rec_map_container(_freeze, array), actx=None)
def thaw(self, array):
return with_array_context(array, actx=self)
def tag(self, tags: ToTagSetConvertible, array):
# Sorry, not capable.
return array
def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
# TODO: See `jax.experimental.maps.xmap`, probably that should be useful?
return array
def call_loopy(self, t_unit, **kwargs):
raise NotImplementedError(
"Calling loopy on JAX arrays is not supported. Maybe rewrite"
" the loopy kernel as numpy-flavored array operations using"
" ArrayContext.np.")
def einsum(self, spec, *args, arg_names=None, tagged=()):
import jax.numpy as jnp
if arg_names is not None:
from warnings import warn
warn("'arg_names' don't bear any significance in "
f"{type(self).__name__}.", stacklevel=2)
return jnp.einsum(spec, *args)
def clone(self):
return type(self)()
# }}}
# {{{ properties
@property
def permits_inplace_modification(self):
return False
@property
def supports_nonscalar_broadcasting(self):
return True
@property
def permits_advanced_indexing(self):
return True
# }}}
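A minimal usage sketch (assuming JAX is installed and ``jax_enable_x64`` is configured as desired)::

    import numpy as np
    from arraycontext import EagerJAXArrayContext

    actx = EagerJAXArrayContext()
    x = actx.from_numpy(np.linspace(0.0, 1.0, 8))
    y = actx.np.sin(x) + 1.0
    print(actx.to_numpy(y))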
# vim: foldmethod=marker
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial, reduce
import numpy as np
import jax.numpy as jnp
from arraycontext.container import (
NotAnArrayContainerError,
serialize_container,
)
from arraycontext.container.traversal import (
rec_map_array_container,
rec_map_reduce_array_container,
rec_multimap_array_container,
)
from arraycontext.context import Array, ArrayOrContainer
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace
class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
# Everything is implemented in the base class for now.
pass
class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
"""
A :mod:`numpy` mimic for :class:`~arraycontext.EagerJAXArrayContext`.
"""
def _get_fake_numpy_linalg_namespace(self):
return EagerJAXFakeNumpyLinalgNamespace(self._array_context)
def __getattr__(self, name):
return partial(rec_multimap_array_container, getattr(jnp, name))
# NOTE: the order of these follows the order in numpy docs
# NOTE: when adding a function here, also add it to `array_context.rst` docs!
# {{{ array creation routines
def zeros(self, shape, dtype):
return jnp.zeros(shape=shape, dtype=dtype)
def empty_like(self, ary):
from warnings import warn
warn(f"{type(self._array_context).__name__}.np.empty_like is "
"deprecated and will stop working in 2023. Prefer actx.np.zeros_like "
"instead.",
DeprecationWarning, stacklevel=2)
def _empty_like(array):
return self._array_context.empty(array.shape, array.dtype)
return self._array_context._rec_map_container(_empty_like, ary)
def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.zeros(array.shape, array.dtype)
return self._array_context._rec_map_container(
_zeros_like, ary, default_scalar=0)
def ones_like(self, ary):
return self.full_like(ary, 1)
def full_like(self, ary, fill_value):
def _full_like(subary):
return jnp.full_like(subary, fill_value)
return self._array_context._rec_map_container(
_full_like, ary, default_scalar=fill_value)
# }}}
# {{{ array manipulation routines
def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: jnp.reshape(ary, newshape, order=order),
a)
def ravel(self, a, order="C"):
"""
.. warning::
Since :func:`jax.numpy.reshape` does not support orders ``A`` and
``K``, in such cases we fall back to using ``order = C``.
"""
if order in "AK":
from warnings import warn
warn(f"ravel with order='{order}' not supported by JAX,"
" using order=C.", stacklevel=1)
order = "C"
return rec_map_array_container(
lambda subary: jnp.ravel(subary, order=order), a)
def transpose(self, a, axes=None):
return rec_multimap_array_container(jnp.transpose, a, axes)
def broadcast_to(self, array, shape):
return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
def concatenate(self, arrays, axis=0):
return rec_multimap_array_container(jnp.concatenate, arrays, axis)
def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: jnp.stack(arrays=args, axis=axis),
*arrays)
# }}}
# {{{ linear algebra
def vdot(self, x, y, dtype=None):
from arraycontext import rec_multimap_reduce_array_container
def _rec_vdot(ary1, ary2):
common_dtype = np.result_type(ary1, ary2)
if dtype not in (None, common_dtype):
raise NotImplementedError(
f"{type(self).__name__} cannot take dtype in vdot.")
return jnp.vdot(ary1, ary2)
return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
# }}}
# {{{ logic functions
def all(self, a):
return rec_map_reduce_array_container(
partial(reduce, jnp.logical_and), jnp.all, a)
def any(self, a):
return rec_map_reduce_array_container(
partial(reduce, jnp.logical_or), jnp.any, a)
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
actx = self._array_context
# NOTE: not all backends support `bool` properly, so use `int8` instead
true_ary = actx.from_numpy(np.int8(True))
false_ary = actx.from_numpy(np.int8(False))
def rec_equal(x, y):
if type(x) is not type(y):
return false_ary
try:
serialized_x = serialize_container(x)
serialized_y = serialize_container(y)
except NotAnArrayContainerError:
if x.shape != y.shape:
return false_ary
else:
return jnp.all(jnp.equal(x, y))
else:
if len(serialized_x) != len(serialized_y):
return false_ary
return reduce(
jnp.logical_and,
[(true_ary if kx_i == ky_i else false_ary)
and rec_equal(x_i, y_i)
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y, strict=True)],
true_ary)
return rec_equal(a, b)
# }}}
# {{{ mathematical functions
def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(
sum,
partial(jnp.sum, axis=axis, dtype=dtype),
a)
def amin(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
min = amin
def amax(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
max = amax
# }}}
# {{{ sorting, searching and counting
def where(self, criterion, then, else_):
return rec_multimap_array_container(jnp.where, criterion, then, else_)
# }}}
"""
.. currentmodule:: arraycontext
A :mod:`numpy`-based array context.
.. autoclass:: NumpyArrayContext
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from typing import Any, overload
import numpy as np
import loopy as lp
from pytools.tag import ToTagSetConvertible
from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ContainerOrScalarT,
NumpyOrContainerOrScalar,
UntransformedCodeWarning,
)
class NumpyNonObjectArrayMetaclass(type):
def __instancecheck__(cls, instance: Any) -> bool:
return isinstance(instance, np.ndarray) and instance.dtype != object
class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass):
pass
class NumpyArrayContext(ArrayContext):
"""
An :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays.
.. automethod:: __init__
"""
_loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]
def __init__(self) -> None:
super().__init__()
self._loopy_transform_cache = {}
array_types = (NumpyNonObjectArray,)
def _get_fake_numpy_namespace(self):
from .fake_numpy import NumpyFakeNumpyNamespace
return NumpyFakeNumpyNamespace(self)
# {{{ ArrayContext interface
def clone(self):
return type(self)()
@overload
def from_numpy(self, array: np.ndarray) -> Array:
...
@overload
def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...
def from_numpy(self,
array: NumpyOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
return array
@overload
def to_numpy(self, array: Array) -> np.ndarray:
...
@overload
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...
def to_numpy(self,
array: ArrayOrContainerOrScalar
) -> NumpyOrContainerOrScalar:
return array
def call_loopy(
self,
t_unit: lp.TranslationUnit, **kwargs: Any
) -> dict[str, Array]:
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
try:
executor = self._loopy_transform_cache[t_unit]
except KeyError:
executor = self.transform_loopy_program(t_unit).executor()
self._loopy_transform_cache[t_unit] = executor
_, result = executor(**kwargs)
return result
def freeze(self, array):
def _freeze(ary):
return ary
return with_array_context(rec_map_array_container(_freeze, array), actx=None)
def thaw(self, array):
def _thaw(ary):
return ary
return with_array_context(rec_map_array_container(_thaw, array), actx=self)
# }}}
def transform_loopy_program(self, t_unit):
from warnings import warn
warn("Using the base "
f"{type(self).__name__}.transform_loopy_program "
"to transform a translation unit. "
"This is a no-op and will result in unoptimized C code for"
"the requested optimization, all in a single statement."
"This will work, but is unlikely to be performant."
f"Instead, subclass {type(self).__name__} and implement "
"the specific transform logic required to transform the program "
"for your package or application. Check higher-level packages "
"(e.g. meshmode), which may already have subclasses you may want "
"to build on.",
UntransformedCodeWarning, stacklevel=2)
return t_unit
def tag(self,
tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
# Numpy doesn't support tagging
return array
def tag_axis(self,
iaxis: int, tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
# Numpy doesn't support tagging
return array
def einsum(self, spec, *args, arg_names=None, tagged=()):
return np.einsum(spec, *args)
@property
def permits_inplace_modification(self):
return True
@property
def supports_nonscalar_broadcasting(self):
return True
@property
def permits_advanced_indexing(self):
return True
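A minimal usage sketch of this context (the import path is an assumption based on where the class is defined; adjust to your installation)::

    import numpy as np
    from arraycontext.impl.numpy import NumpyArrayContext

    actx = NumpyArrayContext()
    x = actx.from_numpy(np.arange(10.0))
    assert actx.to_numpy(actx.np.sum(x)) == 45.0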
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial, reduce
from typing import cast
import numpy as np
from arraycontext.container import NotAnArrayContainerError, serialize_container
from arraycontext.container.traversal import (
rec_map_array_container,
rec_map_reduce_array_container,
rec_multimap_array_container,
rec_multimap_reduce_array_container,
)
from arraycontext.context import Array, ArrayOrContainer
from arraycontext.fake_numpy import (
BaseFakeNumpyLinalgNamespace,
BaseFakeNumpyNamespace,
)
class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
# Everything is implemented in the base class for now.
pass
_NUMPY_UFUNCS = frozenset({"concatenate", "reshape", "transpose",
"ones_like", "where",
*BaseFakeNumpyNamespace._numpy_math_functions
})
class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace):
"""
A :mod:`numpy` mimic for :class:`NumpyArrayContext`.
"""
def _get_fake_numpy_linalg_namespace(self):
return NumpyFakeNumpyLinalgNamespace(self._array_context)
def zeros(self, shape, dtype):
return np.zeros(shape, dtype)
def __getattr__(self, name):
if name in _NUMPY_UFUNCS:
from functools import partial
return partial(rec_multimap_array_container,
getattr(np, name))
raise AttributeError(name)
def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(sum, partial(np.sum,
axis=axis,
dtype=dtype),
a)
def min(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, np.minimum), partial(np.amin, axis=axis), a)
def max(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, np.maximum), partial(np.amax, axis=axis), a)
def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: np.stack(arrays=args, axis=axis),
*arrays)
def broadcast_to(self, array, shape):
return rec_map_array_container(partial(np.broadcast_to, shape=shape), array)
# {{{ relational operators
def equal(self, x, y):
return rec_multimap_array_container(np.equal, x, y)
def not_equal(self, x, y):
return rec_multimap_array_container(np.not_equal, x, y)
def greater(self, x, y):
return rec_multimap_array_container(np.greater, x, y)
def greater_equal(self, x, y):
return rec_multimap_array_container(np.greater_equal, x, y)
def less(self, x, y):
return rec_multimap_array_container(np.less, x, y)
def less_equal(self, x, y):
return rec_multimap_array_container(np.less_equal, x, y)
# }}}
def ravel(self, a, order="C"):
return rec_map_array_container(partial(np.ravel, order=order), a)
def vdot(self, x, y):
return rec_multimap_reduce_array_container(sum, np.vdot, x, y)
def any(self, a):
return rec_map_reduce_array_container(partial(reduce, np.logical_or),
lambda subary: np.any(subary), a)
def all(self, a):
return rec_map_reduce_array_container(partial(reduce, np.logical_and),
lambda subary: np.all(subary), a)
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
false_ary = np.array(False)
true_ary = np.array(True)
if type(a) is not type(b):
return false_ary
try:
serialized_x = serialize_container(a)
serialized_y = serialize_container(b)
except NotAnArrayContainerError:
assert isinstance(a, np.ndarray)
assert isinstance(b, np.ndarray)
return np.array(np.array_equal(a, b))
else:
if len(serialized_x) != len(serialized_y):
return false_ary
return np.logical_and.reduce(
[(true_ary if kx_i == ky_i else false_ary)
and cast(np.ndarray, self.array_equal(x_i, y_i))
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y, strict=True)],
initial=true_ary)
def arange(self, *args, **kwargs):
return np.arange(*args, **kwargs)
def linspace(self, *args, **kwargs):
return np.linspace(*args, **kwargs)
def zeros_like(self, ary):
return rec_map_array_container(np.zeros_like, ary)
def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: ary.reshape(newshape, order=order),
a)
# vim: fdm=marker
""" from __future__ import annotations
__doc__ = """
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
.. autoclass:: PyOpenCLArrayContext .. autoclass:: PyOpenCLArrayContext
.. automodule:: arraycontext.impl.pyopencl.taggable_cl_array
""" """
__copyright__ = """ __copyright__ = """
...@@ -27,19 +31,27 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -27,19 +31,27 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from collections.abc import Callable
from typing import TYPE_CHECKING
from warnings import warn
import numpy as np
from pytools.tag import ToTagSetConvertible
from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ScalarLike,
UntransformedCodeWarning,
)
if TYPE_CHECKING:
import loopy as lp
import pyopencl
# {{{ PyOpenCLArrayContext
@@ -65,13 +77,15 @@ class PyOpenCLArrayContext(ArrayContext):
of arrays are created (e.g. as results of computation), the associated cost
may become significant. Using e.g. :class:`pyopencl.tools.MemoryPool`
as the allocator can help avoid this cost.
.. automethod:: transform_loopy_program
"""
def __init__(self,
queue: pyopencl.CommandQueue,
allocator: pyopencl.tools.AllocatorBase | None = None,
wait_event_queue_length: int | None = None,
force_device_scalars: bool | None = None) -> None:
r""" r"""
:arg wait_event_queue_length: The length of a queue of :arg wait_event_queue_length: The length of a queue of
:class:`~pyopencl.Event` objects that are maintained by the :class:`~pyopencl.Event` objects that are maintained by the
...@@ -92,24 +106,18 @@ class PyOpenCLArrayContext(ArrayContext): ...@@ -92,24 +106,18 @@ class PyOpenCLArrayContext(ArrayContext):
For now, *wait_event_queue_length* should be regarded as an For now, *wait_event_queue_length* should be regarded as an
experimental feature that may change or disappear at any minute. experimental feature that may change or disappear at any minute.
:arg force_device_scalars: if *True*, scalar results returned from
reductions in :attr:`ArrayContext.np` will be kept on the device.
If *False*, the equivalent of :meth:`~ArrayContext.freeze` and
:meth:`~ArrayContext.to_numpy` is applied to transfer the results
to the host.
""" """
if not force_device_scalars: if force_device_scalars is not None:
warn("Configuring the PyOpenCLArrayContext to return host scalars " warn(
"from reductions is deprecated. " "`force_device_scalars` is deprecated and will be removed in 2025.",
"To configure the PyOpenCLArrayContext to return " DeprecationWarning, stacklevel=2)
"device scalars, pass 'force_device_scalars=True' to the "
"constructor. " if not force_device_scalars:
"Support for returning host scalars will be removed in 2022.", raise ValueError(
DeprecationWarning, stacklevel=2) "Passing force_device_scalars=False is not allowed.")
import pyopencl as cl import pyopencl as cl
import pyopencl.array as cla import pyopencl.array as cl_array
super().__init__() super().__init__()
self.context = queue.context self.context = queue.context
@@ -118,67 +126,134 @@ class PyOpenCLArrayContext(ArrayContext):
if wait_event_queue_length is None:
wait_event_queue_length = 10
self._force_device_scalars = True
# Subclasses might still be using the old
# "force_device_scalars: bool = False" interface, in which case we need
# to explicitly pass force_device_scalars=True in clone()
self._passed_force_device_scalars = force_device_scalars is not None
self._wait_event_queue_length = wait_event_queue_length
self._kernel_name_to_wait_event_queue: dict[str, list[cl.Event]] = {}
if queue.device.type & cl.device_type.GPU:
if allocator is None:
warn("PyOpenCLArrayContext created without an allocator on a GPU. "
"This can lead to high numbers of memory allocations. "
"Please consider using a pyopencl.tools.MemoryPool. "
"Run with allocator=False to disable this warning.",
stacklevel=2)
if __debug__:
# Use "running on GPU" as a proxy for "they care about speed".
warn("You are using the PyOpenCLArrayContext on a GPU, but you "
"are running Python in debug mode. Use 'python -O' for "
"a noticeable speed improvement.",
stacklevel=2)
self._loopy_transform_cache: \
dict[lp.TranslationUnit, lp.TranslationUnit] = {}
# TODO: Ideally this should only be `(TaggableCLArray,)`, but
# that would break the logic in the downstream users.
self.array_types = (cl_array.Array,)
def _get_fake_numpy_namespace(self):
from arraycontext.impl.pyopencl.fake_numpy import PyOpenCLFakeNumpyNamespace
return PyOpenCLFakeNumpyNamespace(self)
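For instance, construction with a memory pool (as suggested by the allocator warning above) might look like this sketch::

    import pyopencl as cl
    import pyopencl.tools
    from arraycontext import PyOpenCLArrayContext

    cl_ctx = cl.create_some_context()
    queue = cl.CommandQueue(cl_ctx)
    alloc = cl.tools.MemoryPool(cl.tools.ImmediateAllocator(queue))
    actx = PyOpenCLArrayContext(queue, allocator=alloc)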
def _rec_map_container(
self, func: Callable[[Array], Array], array: ArrayOrContainer,
allowed_types: tuple[type, ...] | None = None, *,
default_scalar: ScalarLike | None = None,
strict: bool = False) -> ArrayOrContainer:
import arraycontext.impl.pyopencl.taggable_cl_array as tga
if allowed_types is None:
# TODO: replace with 'self.array_types' once `cla.Array` support
# is completely removed
allowed_types = (tga.TaggableCLArray,)
def _wrapper(ary):
if isinstance(ary, allowed_types):
return func(ary)
elif not strict and isinstance(ary, self.array_types):
from warnings import warn
warn(f"Invoking {type(self).__name__}.{func.__name__[1:]} with "
f"{type(ary).__name__} will be unsupported in 2023. Use "
"'to_tagged_cl_array' to convert instances to TaggableCLArray.",
DeprecationWarning, stacklevel=2)
return func(tga.to_tagged_cl_array(ary))
elif np.isscalar(ary):
if default_scalar is None:
return ary
else:
return np.array(ary).dtype.type(default_scalar)
else:
raise TypeError(
f"{type(self).__name__}.{func.__name__[1:]} invoked with "
f"an unsupported array type: got '{type(ary).__name__}', "
f"but expected one of {allowed_types}")
return rec_map_array_container(_wrapper, array)
# {{{ ArrayContext interface
def from_numpy(self, array):
import arraycontext.impl.pyopencl.taggable_cl_array as tga
def _from_numpy(ary):
return tga.to_device(self.queue, ary, allocator=self.allocator)
return with_array_context(
self._rec_map_container(_from_numpy, array, (np.ndarray,), strict=True),
actx=self)
def to_numpy(self, array):
def _to_numpy(ary):
return ary.get(queue=self.queue)
return with_array_context(
self._rec_map_container(_to_numpy, array),
actx=None)
def freeze(self, array):
def _freeze(ary):
ary.finish()
return ary.with_queue(None)
return with_array_context(self._rec_map_container(_freeze, array), actx=None)
def thaw(self, array):
def _thaw(ary):
return ary.with_queue(self.queue)
return with_array_context(self._rec_map_container(_thaw, array), actx=self)
def tag(self, tags: ToTagSetConvertible, array):
def _tag(ary):
return ary.tagged(tags)
return self._rec_map_container(_tag, array)
def tag_axis(self, iaxis: int, tags: ToTagSetConvertible, array):
def _tag_axis(ary):
return ary.with_tagged_axis(iaxis, tags)
return self._rec_map_container(_tag_axis, array)
def call_loopy(self, t_unit, **kwargs):
try:
executor = self._loopy_transform_cache[t_unit]
except KeyError:
orig_t_unit = t_unit
executor = self.transform_loopy_program(t_unit).executor(self.context)
self._loopy_transform_cache[orig_t_unit] = executor
del orig_t_unit
evt, result = executor(self.queue, **kwargs, allocator=self.allocator)
if self._wait_event_queue_length is not False:
prg_name = executor.t_unit.default_entrypoint.name
wait_event_queue = self._kernel_name_to_wait_event_queue.setdefault( wait_event_queue = self._kernel_name_to_wait_event_queue.setdefault(
prg_name, []) prg_name, [])
@@ -186,27 +261,37 @@ class PyOpenCLArrayContext(ArrayContext):
            if len(wait_event_queue) > self._wait_event_queue_length:
                wait_event_queue.pop(0).wait()

-        return result
-
-    def freeze(self, array):
-        array.finish()
-        return array.with_queue(None)
-
-    def thaw(self, array):
-        return array.with_queue(self.queue)
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
+
+        # FIXME: Inherit loopy tags for these arrays
+        return {name: tga.to_tagged_cl_array(ary) for name, ary in result.items()}
+
+    def clone(self):
+        if self._passed_force_device_scalars:
+            return type(self)(self.queue, self.allocator,
+                    wait_event_queue_length=self._wait_event_queue_length,
+                    force_device_scalars=True)
+        else:
+            return type(self)(self.queue, self.allocator,
+                    wait_event_queue_length=self._wait_event_queue_length)
    # }}}
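A sketch of how the executor-caching `call_loopy` above is typically driven (assumes a working OpenCL device; the kernel, its name `double`, and `actx` are illustrative choices, not part of the change):

import numpy as np
import pyopencl as cl
import loopy as lp

from arraycontext import PyOpenCLArrayContext

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
actx = PyOpenCLArrayContext(queue, force_device_scalars=True)

prg = lp.make_kernel(
    "{[i]: 0 <= i < n}",
    "out[i] = 2*a[i]",
    name="double")

a = actx.from_numpy(np.arange(100, dtype=np.float64))

# the first call transforms the program and caches its executor; subsequent
# calls with the same translation unit reuse the cached executor. The result
# is a dict of named output arrays.
out = actx.call_loopy(prg, a=a)["out"]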
-    def transform_loopy_program(self, t_unit):
+    # {{{ transform_loopy_program
+
+    def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
        from warnings import warn
-        warn("Using arraycontext.PyOpenCLArrayContext.transform_loopy_program "
-                "to transform a program. This is deprecated and will stop working "
-                "in 2022. Instead, subclass PyOpenCLArrayContext and implement "
-                "the specific logic required to transform the program for your "
-                "package or application. Check higher-level packages "
+        warn("Using the base "
+                f"{type(self).__name__}.transform_loopy_program "
+                "to transform a translation unit. "
+                "This is largely a no-op and unlikely to result in fast generated "
+                "code."
+                f"Instead, subclass {type(self).__name__} and implement "
+                "the specific transform logic required to transform the program "
+                "for your package or application. Check higher-level packages "
                "(e.g. meshmode), which may already have subclasses you may want "
                "to build on.",
-                DeprecationWarning, stacklevel=2)
+                UntransformedCodeWarning, stacklevel=2)
        # accommodate loopy with and without kernel callables
@@ -220,46 +305,17 @@ class PyOpenCLArrayContext(ArrayContext):
                "to create this kernel?")

        all_inames = default_entrypoint.all_inames()
-        # FIXME: This could be much smarter.
-        inner_iname = None
-
-        # import with underscore to avoid DeprecationWarning
-        from arraycontext.metadata import _FirstAxisIsElementsTag
-
-        if (len(default_entrypoint.instructions) == 1
-                and isinstance(default_entrypoint.instructions[0], lp.Assignment)
-                and any(isinstance(tag, _FirstAxisIsElementsTag)
-                    # FIXME: Firedrake branch lacks kernel tags
-                    for tag in getattr(default_entrypoint, "tags", ()))):
-            stmt, = default_entrypoint.instructions
-
-            out_inames = [v.name for v in stmt.assignee.index_tuple]
-            assert out_inames
-            outer_iname = out_inames[0]
-            if len(out_inames) >= 2:
-                inner_iname = out_inames[1]
-
-        elif "iel" in all_inames:
-            outer_iname = "iel"
-
-            if "idof" in all_inames:
-                inner_iname = "idof"
-
-        elif "i0" in all_inames:
+        inner_iname = None
+
+        if "i0" in all_inames:
            outer_iname = "i0"

            if "i1" in all_inames:
                inner_iname = "i1"

-        elif not all_inames:
-            # no loops, nothing to transform
-            return t_unit
-
        else:
-            raise RuntimeError(
-                "Unable to reason what outer_iname and inner_iname "
-                f"needs to be; all_inames is given as: {all_inames}"
-            )
+            return t_unit

        if inner_iname is not None:
            t_unit = lp.split_iname(t_unit, inner_iname, 16, inner_tag="l.0")
@@ -267,18 +323,9 @@ class PyOpenCLArrayContext(ArrayContext):
        return t_unit
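As the warning above advises, applications are expected to subclass and override `transform_loopy_program`. A minimal sketch of such a subclass (the class name and the particular split/tag policy are illustrative, not prescribed by arraycontext):

import loopy as lp

from arraycontext import PyOpenCLArrayContext


class MyPyOpenCLArrayContext(PyOpenCLArrayContext):
    def transform_loopy_program(self, t_unit):
        default_entrypoint = t_unit.default_entrypoint

        # illustrative policy: split the first iname (if any) and map it
        # onto OpenCL work-groups/work-items
        inames = sorted(default_entrypoint.all_inames())
        if inames:
            t_unit = lp.split_iname(t_unit, inames[0], 64,
                    outer_tag="g.0", inner_tag="l.0")

        return t_unit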
-    def tag(self, tags: Union[Sequence[Tag], Tag], array):
-        # Sorry, not capable.
-        return array
-
-    def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array):
-        # Sorry, not capable.
-        return array
+    # }}}

-    def clone(self):
-        return type(self)(self.queue, self.allocator,
-                wait_event_queue_length=self._wait_event_queue_length,
-                force_device_scalars=self._force_device_scalars)
+    # {{{ properties
    @property
    def permits_inplace_modification(self):
@@ -292,6 +339,8 @@ class PyOpenCLArrayContext(ArrayContext):
    def permits_advanced_indexing(self):
        return False

+    # }}}
+
# }}}
# vim: foldmethod=marker # vim: foldmethod=marker