
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (632)
Showing 4402 additions and 204 deletions
# https://editorconfig.org/
# https://github.com/editorconfig/editorconfig-vim
# https://github.com/editorconfig/editorconfig-emacs
root = true
[*]
indent_style = space
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.py]
indent_size = 4
[*.rst]
indent_size = 4
[*.cpp]
indent_size = 2
[*.hpp]
indent_size = 2
# There may be one in doc/
[Makefile]
indent_style = tab
# https://github.com/microsoft/vscode/issues/1679
[*.md]
trim_trailing_whitespace = false
version: 2
updates:
# Set update schedule for GitHub Actions
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
# vim: sw=4
@@ -7,14 +7,15 @@ on:
jobs:
autopush:
name: Automatic push to gitlab.tiker.net
if: startsWith(github.repository, 'inducer/')
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- run: |
mkdir ~/.ssh && echo -e "Host gitlab.tiker.net\n\tStrictHostKeyChecking no\n" >> ~/.ssh/config
eval $(ssh-agent) && echo "$GITLAB_AUTOPUSH_KEY" | ssh-add -
git fetch --unshallow
git push "git@gitlab.tiker.net:inducer/$(basename $GITHUB_REPOSITORY).git" main
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
mirror_github_to_gitlab
env:
GITLAB_AUTOPUSH_KEY: ${{ secrets.GITLAB_AUTOPUSH_KEY }}
@@ -7,181 +7,139 @@ on:
schedule:
- cron: '17 3 * * 0'
concurrency:
group: ${{ github.head_ref || github.ref_name }}
cancel-in-progress: true
jobs:
flake8:
name: Flake8
typos:
name: Typos
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: crate-ci/typos@master
ruff:
name: Ruff
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
-
uses: actions/setup-python@v1
with:
# matches compat target in setup.py
python-version: '3.6'
uses: actions/setup-python@v5
- name: "Main Script"
run: |
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh
. ./prepare-and-run-flake8.sh "$(basename $GITHUB_REPOSITORY)" test examples
pip install ruff
ruff check
pylint:
name: Pylint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: "Main Script"
run: |
sudo apt update
sudo apt install octave openmpi-bin libopenmpi-dev libhdf5-dev
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
echo "- mpi4py" >> $CONDA_ENVIRONMENT
echo "- scipy" >> $CONDA_ENVIRONMENT
echo "-------------------------------------------"
cat $CONDA_ENVIRONMENT
echo "-------------------------------------------"
USE_CONDA_BUILD=1
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh
. ./prepare-and-run-pylint.sh "$(basename $GITHUB_REPOSITORY)" examples/*.py test/test_*.py
pytest3:
name: Pytest Conda Py3
mypy:
name: Mypy
runs-on: ubuntu-latest
strategy:
matrix:
loopy-branch: [main, kernel_callables_v3-edit2]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: "Main Script"
run: |
sed -i "s/loopy.git/loopy.git@${{ matrix.loopy-branch }}/g" requirements.txt
sudo apt update
sudo apt install octave openmpi-bin libopenmpi-dev libhdf5-dev
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
export MPLBACKEND=Agg
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/ci-support.sh
. ./ci-support.sh
build_py_project_in_conda_env
with_echo python -m pip install mpi4py
test_py_project
python -m pip install mypy pytest
./run-mypy.sh
firedrake:
name: Pytest Firedrake
pytest3_pocl:
name: Pytest Conda Py3 POCL
runs-on: ubuntu-latest
container:
image: 'firedrakeproject/firedrake'
steps:
- uses: actions/checkout@v1
- name: "Dependencies"
run: |
. .ci/install-for-firedrake.sh
- name: "Test"
- uses: actions/checkout@v4
- name: "Main Script"
run: |
. /home/firedrake/firedrake/bin/activate
cd test
python -m pytest --tb=native -rxsw test_firedrake_interop.py
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_conda_env
test_py_project
firedrake_examples:
name: Examples Firedrake
pytest3_intel_cl:
name: Pytest Conda Py3 Intel
runs-on: ubuntu-latest
container:
image: 'firedrakeproject/firedrake'
steps:
- uses: actions/checkout@v1
- name: "Dependencies"
run: |
. .ci/install-for-firedrake.sh
- name: "Examples"
- uses: actions/checkout@v4
- name: "Main Script"
run: |
. /home/firedrake/firedrake/bin/activate
. ./.ci/run_firedrake_examples.sh
curl -L -O https://raw.githubusercontent.com/illinois-scicomp/machine-shop-maintenance/main/install-intel-icd.sh
sudo bash ./install-intel-icd.sh
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
echo "- ocl-icd-system" >> "$CONDA_ENVIRONMENT"
sed -i "/pocl/ d" "$CONDA_ENVIRONMENT"
export PYOPENCL_TEST=intel
source /opt/enable-intel-cl.sh
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_conda_env
test_py_project
examples3:
name: Examples Conda Py3
runs-on: ubuntu-latest
strategy:
matrix:
loopy-branch: [main, kernel_callables_v3-edit2]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: "Main Script"
run: |
sed -i "s/loopy.git/loopy.git@${{ matrix.loopy-branch }}/g" requirements.txt
export MPLBACKEND=Agg
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
USE_CONDA_BUILD=1
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/build-py-project-and-run-examples.sh
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-py-project-and-run-examples.sh
. ./build-py-project-and-run-examples.sh
docs:
name: Documentation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
-
uses: actions/setup-python@v1
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: "Main Script"
run: |
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/ci-support.sh
. ci-support.sh
curl -L -O https://tiker.net/ci-support-v0
. ci-support-v0
build_py_project_in_conda_env
conda install graphviz
# Work around
# intersphinx inventory 'https://firedrakeproject.org/objects.inv' not fetchable
# by deleting all the Firedrake stuff
rm -Rf meshmode/interop/firedrake
sed -i '/firedrakeproject/d' doc/conf.py
sed -i '/interop/d' doc/index.rst
rm doc/interop.rst
CI_SUPPORT_SPHINX_VERSION_SPECIFIER=">=4.0"
build_docs
downstream_tests:
strategy:
matrix:
downstream_project: [grudge, pytential, mirgecom]
downstream_project: [meshmode, grudge, mirgecom, mirgecom_examples]
fail-fast: false
name: Tests for downstream project ${{ matrix.downstream_project }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: "Main Script"
env:
DOWNSTREAM_PROJECT: ${{ matrix.downstream_project }}
run: |
if test "$DOWNSTREAM_PROJECT" = "mirgecom"; then
git clone "https://github.com/illinois-ceesd/$DOWNSTREAM_PROJECT.git"
else
git clone "https://github.com/inducer/$DOWNSTREAM_PROJECT.git"
fi
cd "$DOWNSTREAM_PROJECT"
echo "*** $DOWNSTREAM_PROJECT version: $(git rev-parse --short HEAD)"
sed -i "/egg=meshmode/ c git+file://$(readlink -f ..)#egg=meshmode" requirements.txt
export CONDA_ENVIRONMENT=.test-conda-env-py3.yml
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
test_downstream "$DOWNSTREAM_PROJECT"
# Avoid slow or complicated tests in downstream projects
export PYTEST_ADDOPTS="-k 'not (slowtest or octave or mpi)'"
if test "$DOWNSTREAM_PROJECT" = "mirgecom"; then
# can't turn off MPI in mirgecom
sudo apt-get update
sudo apt-get install openmpi-bin libopenmpi-dev
export CONDA_ENVIRONMENT=conda-env.yml
export CISUPPORT_PARALLEL_PYTEST=no
else
sed -i "/mpi4py/ d" requirements.txt
if [[ "$DOWNSTREAM_PROJECT" = "meshmode" ]]; then
python ../examples/simple-dg.py --lazy
fi
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/ci-support.sh
. ./ci-support.sh
build_py_project_in_conda_env
test_py_project
# vim: sw=4
@@ -21,3 +21,6 @@ a.out
.pytest_cache
test/nodal-dg
.pylintrc.yml
.run-pylint.py
Python 3 POCL:
script: |
sed -i "s/loopy.git/loopy.git@$LOOPY_BRANCH/g" requirements.txt
export PY_EXE=python3
export PYOPENCL_TEST=portable:pthread
# cython is here because pytential (for now, for TS) depends on it
export EXTRA_INSTALL="pybind11 cython numpy mako mpi4py oct2py"
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh
export PYOPENCL_TEST=portable:cpu
export EXTRA_INSTALL="jax[cpu]"
export JAX_PLATFORMS=cpu
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh
. ./build-and-test-py-project.sh
tags:
- python3
@@ -16,20 +14,33 @@ Python 3 POCL:
artifacts:
reports:
junit: test/pytest.xml
parallel:
matrix:
- LOOPY_BRANCH: main
- LOOPY_BRANCH: kernel_callables_v3-edit2
Python 3 Nvidia Titan V:
script: |
sed -i "s/loopy.git/loopy.git@$LOOPY_BRANCH/g" requirements.txt
export PY_EXE=python3
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
export PYOPENCL_TEST=nvi:titan
export EXTRA_INSTALL="pybind11 cython numpy mako oct2py"
# cython is here because pytential (for now, for TS) depends on it
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh
. ./build-and-test-py-project.sh
build_py_project_in_venv
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
test_py_project
tags:
- python3
- nvidia-titan-v
except:
- tags
artifacts:
reports:
junit: test/pytest.xml
Python 3 POCL Nvidia Titan V:
script: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
export PYOPENCL_TEST=port:titan
build_py_project_in_venv
test_py_project
tags:
- python3
- nvidia-titan-v
@@ -38,20 +49,12 @@ Python 3 Nvidia Titan V:
artifacts:
reports:
junit: test/pytest.xml
parallel:
matrix:
- LOOPY_BRANCH: main
- LOOPY_BRANCH: kernel_callables_v3-edit2
Python 3 POCL Examples:
script:
- sed -i "s/loopy.git/loopy.git@$LOOPY_BRANCH/g" requirements.txt
- test -n "$SKIP_EXAMPLES" && exit
- export PY_EXE=python3
- export PYOPENCL_TEST=portable:pthread
# cython is here because pytential (for now, for TS) depends on it
- export EXTRA_INSTALL="pybind11 cython numpy mako matplotlib"
- curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/build-py-project-and-run-examples.sh
- export PYOPENCL_TEST=portable:cpu
- curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-py-project-and-run-examples.sh
- ". ./build-py-project-and-run-examples.sh"
tags:
- python3
@@ -59,41 +62,15 @@ Python 3 POCL Examples:
- large-node
except:
- tags
parallel:
matrix:
- LOOPY_BRANCH: main
- LOOPY_BRANCH: kernel_callables_v3-edit2
Python 3 POCL Firedrake:
tags:
- "docker-runner"
image: "firedrakeproject/firedrake"
script:
- . .ci/install-for-firedrake.sh
- cd test
- python -m pytest --tb=native --junitxml=pytest.xml -rxsw test_firedrake_interop.py
artifacts:
reports:
junit: test/pytest.xml
Python 3 POCL Firedrake Examples:
tags:
- "docker-runner"
image: "firedrakeproject/firedrake"
script:
- . .ci/install-for-firedrake.sh
- . ./.ci/run_firedrake_examples.sh
artifacts:
reports:
junit: test/pytest.xml
Python 3 Conda:
script: |
sed -i "s/loopy.git/loopy.git@$LOOPY_BRANCH/g" requirements.txt
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
export MPLBACKEND=Agg
export PYOPENCL_TEST=portable:cpu
# Avoid crashes like https://gitlab.tiker.net/inducer/arraycontext/-/jobs/536021
sed -i 's/jax/jax !=0.4.6/' .test-conda-env-py3.yml
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh
. ./build-and-test-py-project-within-miniconda.sh
tags:
# - docker-runner
@@ -101,35 +78,60 @@ Python 3 Conda:
- large-node
except:
- tags
parallel:
matrix:
- LOOPY_BRANCH: main
- LOOPY_BRANCH: kernel_callables_v3-edit2
Documentation:
script:
- EXTRA_INSTALL="pybind11 cython numpy"
- curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/build-docs.sh
- ". ./build-docs.sh"
script: |
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-docs.sh
CI_SUPPORT_SPHINX_VERSION_SPECIFIER=">=4.0"
. ./build-docs.sh
tags:
- python3
Flake8:
Ruff:
script:
- curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh
- . ./prepare-and-run-flake8.sh "$CI_PROJECT_NAME" test examples
- pipx install ruff
- ruff check
tags:
- python3
- docker-runner
except:
- tags
Pylint:
script: |
export PY_EXE=python3
EXTRA_INSTALL="Cython pybind11 numpy mako matplotlib scipy mpi4py oct2py"
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh
EXTRA_INSTALL="jax[cpu]"
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh
. ./prepare-and-run-pylint.sh "$CI_PROJECT_NAME" examples/*.py test/test_*.py
tags:
- python3
except:
- tags
Mypy:
script: |
EXTRA_INSTALL="mypy pytest"
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_venv
./run-mypy.sh
tags:
- python3
except:
- tags
Downstream:
parallel:
matrix:
- DOWNSTREAM_PROJECT: [meshmode, grudge, mirgecom, mirgecom_examples]
tags:
- large-node
- "docker-runner"
script: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
test_downstream "$DOWNSTREAM_PROJECT"
if [[ "$DOWNSTREAM_PROJECT" = "meshmode" ]]; then
python ../examples/simple-dg.py --lazy
fi
- arg: ignore
val:
- firedrake
- to_firedrake.py
- from_firedrake.py
- test_firedrake_interop.py
- arg: extension-pkg-whitelist
val: mayavi
name: test-conda-env
channels:
- conda-forge
- nodefaults
dependencies:
- python=3
- git
- libhwloc=2
- numpy
# pocl 3.1 required for full SVM functionality
- pocl>=3.1
- mako
- pyopencl
- islpy
- pip
- jax
arraycontext: Choose your favorite ``numpy``-workalike
======================================================
(Caution: vaporware for now! Much of this functionality exists in
`meshmode <https://documen.tician.de/meshmode/>`__ at the moment
and is in the process of being moved here)
.. image:: https://gitlab.tiker.net/inducer/arraycontext/badges/main/pipeline.svg
:alt: Gitlab Build Status
:target: https://gitlab.tiker.net/inducer/arraycontext/commits/main
.. image:: https://github.com/inducer/arraycontext/workflows/CI/badge.svg
:alt: Github Build Status
:target: https://github.com/inducer/arraycontext/actions?query=branch%3Amain+workflow%3ACI
.. image:: https://badge.fury.io/py/arraycontext.png
.. image:: https://badge.fury.io/py/arraycontext.svg
:alt: Python Package Index Release Page
:target: https://pypi.org/project/arraycontext/
* `Source code on Github <https://github.com/inducer/arraycontext>`_
* `Documentation <https://documen.tician.de/arraycontext>`_
GPU arrays? Deferred-evaluation arrays? Just plain ``numpy`` arrays? You'd like your
code to work with all of them? No problem! Comes with pre-made array context
implementations for:
- numpy
- `PyOpenCL <https://documen.tician.de/pyopencl/array.html>`__
- `Pytato <https://documen.tician.de/pytato>`__
- `JAX <https://jax.readthedocs.io/en/latest/>`__
- `Pytato <https://documen.tician.de/pytato>`__ (for lazy/deferred evaluation)
with backends for ``pyopencl`` and ``jax``.
- Debugging
- Profiling
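A minimal usage sketch (illustrative only; it assumes the ``NumpyArrayContext``
listed above, together with the ``from_numpy``/``to_numpy`` methods and the
``actx.np`` workalike namespace of ``ArrayContext``)::

    import numpy as np
    from arraycontext import NumpyArrayContext

    actx = NumpyArrayContext()

    # move data into the array context, compute via its numpy workalike,
    # then pull the result back out as a plain numpy array
    x = actx.from_numpy(np.linspace(0, 1, 10))
    y = actx.np.sin(x) ** 2
    y_host = actx.to_numpy(y)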
``arraycontext`` started life as an array abstraction for use with the
`meshmode <https://documen.tician.de/meshmode/>`__ unstructured discretization
package.
Distributed under the MIT license.
Links
-----
* `Source code on Github <https://github.com/inducer/arraycontext>`_
* `Documentation <https://documen.tician.de/arraycontext>`_
"""
An array context is an abstraction that helps you dispatch between multiple
implementations of :mod:`numpy`-like :math:`n`-dimensional arrays.
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
@@ -22,5 +29,177 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from .container import (
ArithArrayContainer,
ArrayContainer,
ArrayContainerT,
NotAnArrayContainerError,
SerializationKey,
SerializedContainer,
deserialize_container,
get_container_context_opt,
get_container_context_recursively,
get_container_context_recursively_opt,
is_array_container,
is_array_container_type,
register_multivector_as_array_container,
serialize_container,
)
from .container.arithmetic import (
BcastUntilActxArray,
with_container_arithmetic,
)
from .container.dataclass import dataclass_array_container
from .container.traversal import (
flat_size_and_dtype,
flatten,
freeze,
from_numpy,
map_array_container,
map_reduce_array_container,
mapped_over_array_containers,
multimap_array_container,
multimap_reduce_array_container,
multimapped_over_array_containers,
outer,
rec_map_array_container,
rec_map_reduce_array_container,
rec_multimap_array_container,
rec_multimap_reduce_array_container,
stringify_array_container_tree,
thaw,
to_numpy,
unflatten,
with_array_context,
)
from .context import (
Array,
ArrayContext,
ArrayOrArithContainer,
ArrayOrArithContainerOrScalar,
ArrayOrArithContainerOrScalarT,
ArrayOrArithContainerT,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ArrayOrContainerT,
ArrayT,
Scalar,
ScalarLike,
tag_axes,
)
from .impl.jax import EagerJAXArrayContext
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .loopy import make_loopy_program
from .pytest import (
PytestArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
pytest_generate_tests_for_array_contexts,
)
from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
__all__ = (
"ArithArrayContainer",
"Array",
"ArrayContainer",
"ArrayContainerT",
"ArrayContext",
"ArrayOrArithContainer",
"ArrayOrArithContainerOrScalar",
"ArrayOrArithContainerOrScalarT",
"ArrayOrArithContainerT",
"ArrayOrContainer",
"ArrayOrContainerOrScalar",
"ArrayOrContainerOrScalarT",
"ArrayOrContainerT",
"ArrayT",
"BcastUntilActxArray",
"CommonSubexpressionTag",
"EagerJAXArrayContext",
"ElementwiseMapKernelTag",
"NotAnArrayContainerError",
"NumpyArrayContext",
"PyOpenCLArrayContext",
"PytatoJAXArrayContext",
"PytatoPyOpenCLArrayContext",
"PytestArrayContextFactory",
"PytestPyOpenCLArrayContextFactory",
"Scalar",
"ScalarLike",
"SerializationKey",
"SerializedContainer",
"dataclass_array_container",
"deserialize_container",
"flat_size_and_dtype",
"flatten",
"freeze",
"from_numpy",
"get_container_context_opt",
"get_container_context_recursively",
"get_container_context_recursively_opt",
"is_array_container",
"is_array_container_type",
"make_loopy_program",
"map_array_container",
"map_reduce_array_container",
"mapped_over_array_containers",
"multimap_array_container",
"multimap_reduce_array_container",
"multimapped_over_array_containers",
"outer",
"pytest_generate_tests_for_array_contexts",
"rec_map_array_container",
"rec_map_reduce_array_container",
"rec_multimap_array_container",
"rec_multimap_reduce_array_container",
"register_multivector_as_array_container",
"serialize_container",
"stringify_array_container_tree",
"tag_axes",
"thaw",
"to_numpy",
"unflatten",
"with_array_context",
"with_container_arithmetic",
)
# {{{ deprecation handling
def _deprecated_acf():
"""A tiny undocumented function to pass to tests that take an ``actx_factory``
argument when running them from the command line.
"""
import pyopencl as cl
context = cl._csc()
queue = cl.CommandQueue(context)
return PyOpenCLArrayContext(queue)
_depr_name_to_replacement_and_obj = {
"get_container_context": (
"get_container_context_opt",
get_container_context_opt, 2022),
}
def __getattr__(name):
replacement_and_obj = _depr_name_to_replacement_and_obj.get(name)
if replacement_and_obj is not None:
replacement, obj, year = replacement_and_obj
from warnings import warn
warn(f"'arraycontext.{name}' is deprecated. "
f"Use '{replacement}' instead. "
f"'arraycontext.{name}' will continue to work until {year}.",
DeprecationWarning, stacklevel=2)
return obj
else:
raise AttributeError(name)
# }}}
# vim: foldmethod=marker
# mypy: disallow-untyped-defs
"""
.. currentmodule:: arraycontext
.. autoclass:: ArrayContainer
.. autoclass:: ArithArrayContainer
.. class:: ArrayContainerT
A type variable with a lower bound of :class:`ArrayContainer`.
.. autoexception:: NotAnArrayContainerError
Serialization/deserialization
-----------------------------
.. autoclass:: SerializationKey
.. autoclass:: SerializedContainer
.. autofunction:: is_array_container_type
.. autofunction:: serialize_container
.. autofunction:: deserialize_container
Context retrieval
-----------------
.. autofunction:: get_container_context_opt
.. autofunction:: get_container_context_recursively
.. autofunction:: get_container_context_recursively_opt
:class:`~pymbolic.geometric_algebra.MultiVector` support
---------------------------------------------------------
.. autofunction:: register_multivector_as_array_container
.. currentmodule:: arraycontext.container
Canonical locations for type annotations
----------------------------------------
.. class:: ArrayContainerT
:canonical: arraycontext.ArrayContainerT
.. class:: ArrayOrContainerT
:canonical: arraycontext.ArrayOrContainerT
.. class:: SerializationKey
:canonical: arraycontext.SerializationKey
.. class:: SerializedContainer
:canonical: arraycontext.SerializedContainer
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from collections.abc import Hashable, Sequence
from functools import singledispatch
from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar
# For use in singledispatch type annotations, because sphinx can't figure out
# what 'np' is.
import numpy
import numpy as np
from typing_extensions import Self
from arraycontext.context import ArrayContext, ArrayOrScalar
if TYPE_CHECKING:
from pymbolic.geometric_algebra import MultiVector
from arraycontext import ArrayOrContainer
# {{{ ArrayContainer
class ArrayContainer(Protocol):
"""
A protocol for generic containers of the array type supported by the
:class:`ArrayContext`.
The functionality required for the container to operate is supplied via
:func:`functools.singledispatch`. Implementations of the following functions need
to be registered for a type serving as an :class:`ArrayContainer`:
* :func:`serialize_container` for serialization, which gives the components
of the array.
* :func:`deserialize_container` for deserialization, which constructs a
container from a set of components.
* :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
a container, if it has one.
This allows enumeration of the component arrays in a container and the
construction of modified containers from an iterable of those component arrays.
Packages may register their own types as array containers. They must not
register other types (e.g. :class:`list`) as array containers.
The type :class:`numpy.ndarray` is considered an array container, but
only arrays with dtype *object* may be used as such. (This is so
because object arrays cannot be distinguished from non-object arrays
via their type.)
The container and its serialization interface have goals and use
approaches similar to JAX's
`PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__;
however, the implementation differs a bit.
.. note::
This class is used in type annotation and as a marker of array container
attributes for :func:`~arraycontext.dataclass_array_container`.
As a protocol, it is not intended as a superclass.
"""
# Array containers do not need to have any particular features, so this
# protocol is deliberately empty.
# This *is* used as a type annotation in dataclasses that are processed
# by dataclass_array_container, where it's used to recognize attributes
# that are container-typed.
class ArithArrayContainer(ArrayContainer, Protocol):
"""
A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
"""
# This is loose and permissive, assuming that any array can be added
# to any container. The alternative would be to plaster type-ignores
# on all those uses. Achieving typing precision on what broadcasting is
# allowable seems like a huge endeavor and is likely not feasible without
# a mypy plugin. Maybe some day? -AK, November 2024
def __neg__(self) -> Self: ...
def __abs__(self) -> Self: ...
def __add__(self, other: ArrayOrScalar | Self) -> Self: ...
def __radd__(self, other: ArrayOrScalar | Self) -> Self: ...
def __sub__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ...
def __mul__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ...
def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ...
def __pow__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...
ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)
class NotAnArrayContainerError(TypeError):
""":class:`TypeError` subclass raised when an array container is expected."""
SerializationKey: TypeAlias = Hashable
SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]]
@singledispatch
def serialize_container(
ary: ArrayContainer) -> SerializedContainer:
r"""Serialize the array container into a sequence over its components.
The order of the components and their identifiers are entirely under
the control of the container class. However, the order is required to be
deterministic, i.e. two calls to :func:`serialize_container` on
array containers of the same types with the same number of
sub-arrays must result in a sequence with the keys in the same
order.
If *ary* is mutable, the serialization function is not required to ensure
that the serialization result reflects the array state at the time of the
call to :func:`serialize_container`.
:returns: a :class:`Sequence` of 2-tuples where the first
entry is an identifier for the component and the second entry
is an array-like component of the :class:`ArrayContainer`.
Components can themselves be :class:`ArrayContainer`\ s, allowing
for arbitrarily nested structures. The identifiers need to be hashable
but are otherwise treated as opaque.
"""
raise NotAnArrayContainerError(
f"'{type(ary).__name__}' cannot be serialized as a container")
@singledispatch
def deserialize_container(
template: ArrayContainerT,
serialized: SerializedContainer) -> ArrayContainerT:
"""Deserialize a sequence into an array container following a *template*.
:param template: an instance of an existing object that
can be used to aid in the deserialization. For a similar choice
see :attr:`~numpy.class.__array_finalize__`.
:param serialized: a sequence that mirrors the output of
:meth:`serialize_container`.
"""
raise NotAnArrayContainerError(
f"'{type(template).__name__}' cannot be deserialized as a container")
def is_array_container_type(cls: type) -> bool:
"""
:returns: *True* if the type *cls* has a registered implementation of
:func:`serialize_container`, or if it is an :class:`ArrayContainer`.
.. warning::
Not every instance of a type that this function labels an array container
type is necessarily an array container. For example, while this
function will say that :class:`numpy.ndarray` is an array container
type, only object arrays *actually are* array containers.
"""
assert isinstance(cls, type), f"must pass a {type!r}, not a '{cls!r}'"
return (
cls is ArrayContainer
or (serialize_container.dispatch(cls)
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]
def is_array_container(ary: object) -> bool:
"""
:returns: *True* if the instance *ary* has a registered implementation of
:func:`serialize_container`.
"""
from warnings import warn
warn("is_array_container is deprecated and will be removed in 2022. "
"If you must know precisely whether something is an array container, "
"try serializing it and catch NotAnArrayContainerError. For a "
"cheaper option, see is_array_container_type.",
DeprecationWarning, stacklevel=2)
return (serialize_container.dispatch(ary.__class__)
is not serialize_container.__wrapped__ # type:ignore[attr-defined]
# numpy values with scalar elements aren't array containers
and not (isinstance(ary, np.ndarray)
and ary.dtype.kind != "O")
)
@singledispatch
def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None:
"""Retrieves the :class:`ArrayContext` from the container, if any.
This function is not recursive, so it will only search at the root level
of the container. For the recursive version, see
:func:`get_container_context_recursively`.
"""
return getattr(ary, "array_context", None)
# }}}
# {{{ object arrays as array containers
@serialize_container.register(np.ndarray)
def _serialize_ndarray_container(
ary: numpy.ndarray) -> SerializedContainer:
if ary.dtype.char != "O":
raise NotAnArrayContainerError(
f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'")
# special-cased for speed
if ary.ndim == 1:
return [(i, ary[i]) for i in range(ary.shape[0])]
elif ary.ndim == 2:
return [((i, j), ary[i, j])
for i in range(ary.shape[0])
for j in range(ary.shape[1])
]
else:
return list(np.ndenumerate(ary))
@deserialize_container.register(np.ndarray)
# https://github.com/python/mypy/issues/13040
def _deserialize_ndarray_container( # type: ignore[misc]
template: numpy.ndarray,
serialized: SerializedContainer) -> numpy.ndarray:
# disallow subclasses
assert type(template) is np.ndarray
assert template.dtype.char == "O"
result = type(template)(template.shape, dtype=object)
for i, subary in serialized:
# FIXME: numpy annotations don't seem to handle object arrays very well
result[i] = subary # type: ignore[call-overload]
return result
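# ------------------------------------------------------------------------
# Illustrative sketch, not part of the library: with the registrations above,
# numpy *object* arrays round-trip through the serialization interface, while
# arrays of non-object dtype are rejected.
def _example_object_array_roundtrip() -> None:
    ary = np.empty(2, dtype=object)
    ary[0] = np.zeros(3)
    ary[1] = np.ones(3)

    # components are enumerated as (index, sub-array) pairs
    pairs = serialize_container(ary)
    rebuilt = deserialize_container(ary, pairs)
    assert rebuilt.shape == ary.shape

    try:
        serialize_container(np.zeros(3))
    except NotAnArrayContainerError:
        # non-object dtypes are not treated as containers
        pass
# ------------------------------------------------------------------------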
# }}}
# {{{ get_container_context_recursively
def get_container_context_recursively_opt(
ary: ArrayContainer) -> ArrayContext | None:
"""Walks the :class:`ArrayContainer` hierarchy to find an
:class:`ArrayContext` associated with it.
If components with different array contexts are found at any level,
an assertion error is raised.
Returns *None* if no array context was found.
"""
# try getting the array context directly
actx = get_container_context_opt(ary)
if actx is not None:
return actx
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError:
return actx
else:
for _, subary in iterable:
context = get_container_context_recursively_opt(subary)
if context is None:
continue
if not __debug__:
return context
elif actx is None:
actx = context
else:
assert actx is context
return actx
def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None:
"""Walks the :class:`ArrayContainer` hierarchy to find an
:class:`ArrayContext` associated with it.
If components with different array contexts are found at any level,
an assertion error is raised.
Raises an error if no array context is found.
"""
actx = get_container_context_recursively_opt(ary)
if actx is None:
# raise ValueError("no array context was found")
from warnings import warn
warn("No array context was found. This will be an error starting in "
"July of 2022. If you would like the function to return "
"None if no array context was found, use "
"get_container_context_recursively_opt.",
DeprecationWarning, stacklevel=2)
return actx
# }}}
# {{{ MultiVector support, see pymbolic.geometric_algebra
# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic.
# (Though clearly there exists a dependency via loopy.)
def _serialize_multivec_as_container(mv: MultiVector) -> SerializedContainer:
return list(mv.data.items())
# FIXME: Ignored due to https://github.com/python/mypy/issues/13040
def _deserialize_multivec_as_container( # type: ignore[misc]
template: MultiVector,
serialized: SerializedContainer) -> MultiVector:
from pymbolic.geometric_algebra import MultiVector
return MultiVector(dict(serialized), space=template.space)
def _get_container_context_opt_from_multivec(mv: MultiVector) -> None:
return None
def register_multivector_as_array_container() -> None:
"""Registers :class:`~pymbolic.geometric_algebra.MultiVector` as an
:class:`ArrayContainer`. This function may be called multiple times. The
second and subsequent calls have no effect.
"""
from pymbolic.geometric_algebra import MultiVector
if MultiVector not in serialize_container.registry:
serialize_container.register(MultiVector)(_serialize_multivec_as_container)
deserialize_container.register(MultiVector)(
_deserialize_multivec_as_container)
get_container_context_opt.register(MultiVector)(
_get_container_context_opt_from_multivec)
assert MultiVector in serialize_container.registry
# }}}
# vim: foldmethod=marker
# mypy: disallow-untyped-defs
from __future__ import annotations
__doc__ = """
.. currentmodule:: arraycontext
.. autofunction:: with_container_arithmetic
.. autoclass:: BcastUntilActxArray
"""
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import enum
import operator
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import partialmethod
from numbers import Number
from typing import Any, TypeVar
from warnings import warn
import numpy as np
from arraycontext.container import (
NotAnArrayContainerError,
deserialize_container,
serialize_container,
)
from arraycontext.context import ArrayContext, ArrayOrContainer
# {{{ with_container_arithmetic
T = TypeVar("T")
@enum.unique
class _OpClass(enum.Enum):
ARITHMETIC = enum.auto()
MATMUL = enum.auto()
BITWISE = enum.auto()
SHIFT = enum.auto()
EQ_COMPARISON = enum.auto()
REL_COMPARISON = enum.auto()
_UNARY_OP_AND_DUNDER = [
("pos", "+{}", _OpClass.ARITHMETIC),
("neg", "-{}", _OpClass.ARITHMETIC),
("abs", "abs({})", _OpClass.ARITHMETIC),
("inv", "~{}", _OpClass.BITWISE),
]
_BINARY_OP_AND_DUNDER = [
("add", "{} + {}", True, _OpClass.ARITHMETIC),
("sub", "{} - {}", True, _OpClass.ARITHMETIC),
("mul", "{} * {}", True, _OpClass.ARITHMETIC),
("truediv", "{} / {}", True, _OpClass.ARITHMETIC),
("floordiv", "{} // {}", True, _OpClass.ARITHMETIC),
("pow", "{} ** {}", True, _OpClass.ARITHMETIC),
("mod", "{} % {}", True, _OpClass.ARITHMETIC),
("divmod", "divmod({}, {})", True, _OpClass.ARITHMETIC),
("matmul", "{} @ {}", True, _OpClass.MATMUL),
("and", "{} & {}", True, _OpClass.BITWISE),
("or", "{} | {}", True, _OpClass.BITWISE),
("xor", "{} ^ {}", True, _OpClass.BITWISE),
("lshift", "{} << {}", False, _OpClass.SHIFT),
("rshift", "{} >> {}", False, _OpClass.SHIFT),
("eq", "{} == {}", False, _OpClass.EQ_COMPARISON),
("ne", "{} != {}", False, _OpClass.EQ_COMPARISON),
("lt", "{} < {}", False, _OpClass.REL_COMPARISON),
("gt", "{} > {}", False, _OpClass.REL_COMPARISON),
("le", "{} <= {}", False, _OpClass.REL_COMPARISON),
("ge", "{} >= {}", False, _OpClass.REL_COMPARISON),
]
def _format_unary_op_str(op_str: str, arg1: tuple[str, ...] | str) -> str:
if isinstance(arg1, tuple):
arg1_entry, arg1_container = arg1
return (f"{op_str.format(arg1_entry)} "
f"for {arg1_entry} in {arg1_container}")
else:
return op_str.format(arg1)
def _format_binary_op_str(op_str: str,
arg1: tuple[str, str] | str,
arg2: tuple[str, str] | str) -> str:
if isinstance(arg1, tuple) and isinstance(arg2, tuple):
arg1_entry, arg1_container = arg1
arg2_entry, arg2_container = arg2
return (f"{op_str.format(arg1_entry, arg2_entry)} "
f"for {arg1_entry}, {arg2_entry} "
f"in zip({arg1_container}, {arg2_container}, strict=__debug__)")
elif isinstance(arg1, tuple):
arg1_entry, arg1_container = arg1
return (f"{op_str.format(arg1_entry, arg2)} "
f"for {arg1_entry} in {arg1_container}")
elif isinstance(arg2, tuple):
arg2_entry, arg2_container = arg2
return (f"{op_str.format(arg1, arg2_entry)} "
f"for {arg2_entry} in {arg2_container}")
else:
return op_str.format(arg1, arg2)
class NumpyObjectArrayMetaclass(type):
def __instancecheck__(cls, instance: Any) -> bool:
return isinstance(instance, np.ndarray) and instance.dtype == object
class NumpyObjectArray(metaclass=NumpyObjectArrayMetaclass):
pass
class ComplainingNumpyNonObjectArrayMetaclass(type):
def __instancecheck__(cls, instance: Any) -> bool:
if isinstance(instance, np.ndarray) and instance.dtype != object:
# Example usage site:
# https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149
# where normal is passed in by test_lfr_flux as a 'custom-made'
# numpy array of dtype float64.
warn(
"Broadcasting container against non-object numpy array. "
"This was never documented to work and will now stop working in "
"2025. Convert the array to an object array or use "
"arraycontext.BcastUntilActxArray (or similar) to obtain the desired "
"broadcasting semantics.", DeprecationWarning, stacklevel=3)
return True
else:
return False
class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMetaclass):
pass
def with_container_arithmetic(
*,
number_bcasts_across: bool | None = None,
bcasts_across_obj_array: bool | None = None,
container_types_bcast_across: tuple[type, ...] | None = None,
arithmetic: bool = True,
matmul: bool = False,
bitwise: bool = False,
shift: bool = False,
_cls_has_array_context_attr: bool | None = None,
eq_comparison: bool | None = None,
rel_comparison: bool | None = None,
# deprecated:
bcast_number: bool | None = None,
bcast_obj_array: bool | None = None,
bcast_numpy_array: bool = False,
_bcast_actx_array_type: bool | None = None,
bcast_container_types: tuple[type, ...] | None = None,
) -> Callable[[type], type]:
"""A class decorator that implements built-in operators for array containers
by propagating the operations to the elements of the container.
:arg number_bcasts_across: If *True*, numbers broadcast over the container
(with the container as the 'outer' structure).
:arg bcasts_across_obj_array: If *True*, this container will be broadcast
across :mod:`numpy` object arrays
(with the object array as the 'outer' structure).
Add :class:`numpy.ndarray` to *container_types_bcast_across* to achieve
the 'reverse' broadcasting.
:arg container_types_bcast_across: A sequence of container types that will broadcast
across this container, with this container as the 'outer' structure.
:class:`numpy.ndarray` is permitted to be part of this sequence to
indicate that object arrays (and *only* object arrays) will be broadcast.
In this case, *bcasts_across_obj_array* must be *False*.
:arg arithmetic: Implement the conventional arithmetic operators, including
``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as
:func:`abs`.
:arg bitwise: If *True*, implement bitwise and, or, not, and inversion.
:arg shift: If *True*, implement bit shifts.
:arg eq_comparison: If *True*, implement ``==`` and ``!=``.
:arg rel_comparison: If *True*, implement ``<``, ``<=``, ``>``, ``>=``.
In that case, if *eq_comparison* is unspecified, it is also set to
*True*.
:arg _cls_has_array_context_attr: A flag indicating whether the decorated
class has an ``array_context`` attribute. If so, and if :data:`__debug__`
is *True*, an additional check is performed in binary operators
to ensure that both containers use the same array context.
If *None* (the default), this value is set based on whether the class
has an ``array_context`` attribute.
Consider this argument an unstable interface. It may disappear at any moment.
Each operator class also includes the "reverse" operators if applicable.
.. note::
For the generated binary arithmetic operators, if certain types
should be broadcast over the container (with the container as the
'outer' structure) but are not handled in this way by their types,
you may wrap them in :class:`BcastUntilActxArray` to achieve
the desired semantics.
.. note::
To generate the code implementing the operators, this function relies on
class methods ``_deserialize_init_arrays_code`` and
``_serialize_init_arrays_code``. This interface should be considered
undocumented and subject to change, however if you are curious, you may look
at its implementation in :class:`meshmode.dof_array.DOFArray`. For a simple
structure type, the implementation might look like this::
@classmethod
def _serialize_init_arrays_code(cls, instance_name):
return {"u": f"{instance_name}.u", "v": f"{instance_name}.v"}
@classmethod
def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
return f"u={args['u']}, v={args['v']}"
:func:`dataclass_array_container` automatically generates an appropriate
implementation of these methods, so :func:`with_container_arithmetic`
should nest "outside" :func:`dataclass_array_container`.
"""
# Hard-won design lessons:
#
# - Anything that special-cases np.ndarray by type is broken by design because:
# - np.ndarray is an array context array.
# - numpy object arrays can be array containers.
# Using NumpyObjectArray and NumpyNonObjectArray *may* be better?
# They're new, so there is no operational experience with them.
#
# - Broadcast rules are hard to change once established, particularly
# because one cannot grep for their use.
# {{{ handle inputs
if rel_comparison and eq_comparison is None:
eq_comparison = True
if eq_comparison is None:
raise TypeError("eq_comparison must be specified")
# {{{ handle bcast_number
if bcast_number is not None:
if number_bcasts_across is not None:
raise TypeError(
"may specify at most one of 'bcast_number' and "
"'number_bcasts_across'")
warn("'bcast_number' is deprecated and will be unsupported from 2025. "
"Use 'number_bcasts_across', with equivalent meaning.",
DeprecationWarning, stacklevel=2)
number_bcasts_across = bcast_number
else:
if number_bcasts_across is None:
number_bcasts_across = True
del bcast_number
# }}}
# {{{ handle bcast_obj_array
if bcast_obj_array is not None:
if bcasts_across_obj_array is not None:
raise TypeError(
"may specify at most one of 'bcast_obj_array' and "
"'bcasts_across_obj_array'")
warn("'bcast_obj_array' is deprecated and will be unsupported from 2025. "
"Use 'bcasts_across_obj_array', with equivalent meaning.",
DeprecationWarning, stacklevel=2)
bcasts_across_obj_array = bcast_obj_array
else:
if bcasts_across_obj_array is None:
raise TypeError("bcasts_across_obj_array must be specified")
del bcast_obj_array
# }}}
# {{{ handle bcast_container_types
if bcast_container_types is not None:
if container_types_bcast_across is not None:
raise TypeError(
"may specify at most one of 'bcast_container_types' and "
"'container_types_bcast_across'")
warn("'bcast_container_types' is deprecated and will be unsupported from 2025. "
"Use 'container_types_bcast_across', with equivalent meaning.",
DeprecationWarning, stacklevel=2)
container_types_bcast_across = bcast_container_types
else:
if container_types_bcast_across is None:
container_types_bcast_across = ()
del bcast_container_types
# }}}
if rel_comparison is None:
raise TypeError("rel_comparison must be specified")
if bcast_numpy_array:
warn("'bcast_numpy_array=True' is deprecated and will be unsupported"
" from 2025.", DeprecationWarning, stacklevel=2)
if _bcast_actx_array_type:
raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'"
" cannot be both set.")
if not bcasts_across_obj_array and bcast_numpy_array:
raise TypeError("bcast_obj_array must be set if bcast_numpy_array is")
if bcast_numpy_array:
def numpy_pred(name: str) -> str:
return f"is_numpy_array({name})"
elif bcasts_across_obj_array:
def numpy_pred(name: str) -> str:
return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'"
else:
def numpy_pred(name: str) -> str:
return "False" # optimized away
if np.ndarray in container_types_bcast_across and bcasts_across_obj_array:
raise ValueError("If numpy.ndarray is part of bcast_container_types, "
"bcast_obj_array must be False.")
numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray]
container_types_bcast_across = tuple(
new_ct
for old_ct in container_types_bcast_across
for new_ct in
(numpy_check_types
if old_ct is np.ndarray
else [old_ct])
)
desired_op_classes = set()
if arithmetic:
desired_op_classes.add(_OpClass.ARITHMETIC)
if matmul:
desired_op_classes.add(_OpClass.MATMUL)
if bitwise:
desired_op_classes.add(_OpClass.BITWISE)
if shift:
desired_op_classes.add(_OpClass.SHIFT)
if eq_comparison:
desired_op_classes.add(_OpClass.EQ_COMPARISON)
if rel_comparison:
desired_op_classes.add(_OpClass.REL_COMPARISON)
# }}}
def wrap(cls: Any) -> Any:
if not hasattr(cls, "__array_ufunc__"):
warn(f"{cls} does not have __array_ufunc__ set. "
"This will cause numpy to attempt broadcasting, in a way that "
"is likely undesired. "
f"To avoid this, set __array_ufunc__ = None in {cls}.",
stacklevel=2)
cls_has_array_context_attr: bool | None = _cls_has_array_context_attr
bcast_actx_array_type: bool | None = _bcast_actx_array_type
if cls_has_array_context_attr is None and hasattr(cls, "array_context"):
raise TypeError(
f"{cls} has an 'array_context' attribute, but it does not "
"set '_cls_has_array_context_attr' to *True* when calling "
"with_container_arithmetic. This is being interpreted "
"as '.array_context' being permitted to fail "
"with an exception, which is no longer allowed. "
f"If {cls.__name__}.array_context will not fail, pass "
"'_cls_has_array_context_attr=True'. "
"If you do not want container arithmetic to make "
"use of the array context, set "
"'_cls_has_array_context_attr=False'.")
if bcast_actx_array_type is None:
if cls_has_array_context_attr:
if number_bcasts_across:
bcast_actx_array_type = cls_has_array_context_attr
else:
bcast_actx_array_type = False
else:
if bcast_actx_array_type and not cls_has_array_context_attr:
raise TypeError("_bcast_actx_array_type can be True only if "
"_cls_has_array_context_attr is set.")
if bcast_actx_array_type:
if _bcast_actx_array_type:
warn(
f"Broadcasting array context array types across {cls} "
"has been explicitly "
"enabled. As of 2026, this will stop working. "
"Use arraycontext.Bcast* object wrappers for "
"roughly equivalent functionality. "
"See the discussion in "
"https://github.com/inducer/arraycontext/pull/190. "
"To opt out now (and avoid this warning), "
"pass _bcast_actx_array_type=False. ",
DeprecationWarning, stacklevel=2)
else:
warn(
f"Broadcasting array context array types across {cls} "
"has been implicitly "
"enabled. As of 2026, this will no longer work. "
"Use arraycontext.Bcast* object wrappers for "
"roughly equivalent functionality. "
"See the discussion in "
"https://github.com/inducer/arraycontext/pull/190. "
"To opt out now (and avoid this warning), "
"pass _bcast_actx_array_type=False.",
DeprecationWarning, stacklevel=2)
if (not hasattr(cls, "_serialize_init_arrays_code")
or not hasattr(cls, "_deserialize_init_arrays_code")):
raise TypeError(f"class '{cls.__name__}' must provide serialization "
"code to generate arithmetic operations by implementing "
"'_serialize_init_arrays_code' and "
"'_deserialize_init_arrays_code'. If this is a dataclass, "
"use the 'dataclass_array_container' decorator first.")
from pytools.codegen import CodeGenerator, Indentation
gen = CodeGenerator()
gen(f"""
from numbers import Number
import numpy as np
from arraycontext import ArrayContainer
from warnings import warn
def _raise_if_actx_none(actx):
if actx is None:
raise ValueError("array containers with frozen arrays "
"cannot be operated upon")
return actx
def is_numpy_array(arg):
if isinstance(arg, np.ndarray):
if arg.dtype != "O":
warn("Operand is a non-object numpy array, "
"and the broadcasting behavior of this array container "
"({cls}) "
"is influenced by this because of its use of "
"the deprecated bcast_numpy_array. This broadcasting "
"behavior will change in 2025. If you would like the "
"broadcasting behavior to stay the same, make sure "
"to convert the passed numpy array to an "
"object array.",
DeprecationWarning, stacklevel=3)
return True
else:
return False
""")
gen("")
if container_types_bcast_across:
for i, bct in enumerate(container_types_bcast_across):
gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}")
gen("")
container_type_names_bcast_across = tuple(
f"_bctype{i}" for i in range(len(container_types_bcast_across)))
if number_bcasts_across:
container_type_names_bcast_across += ("Number",)
def same_key(k1: T, k2: T) -> T:
assert k1 == k2
return k1
def tup_str(t: tuple[str, ...]) -> str:
if not t:
return "()"
else:
return "({},)".format(", ".join(t))
gen(f"cls._outer_bcast_types = {tup_str(container_type_names_bcast_across)}")
gen("cls._container_types_bcast_across = "
f"{tup_str(container_type_names_bcast_across)}")
gen(f"cls._bcast_numpy_array = {bcast_numpy_array}")
gen(f"cls._bcast_obj_array = {bcasts_across_obj_array}")
gen(f"cls._bcasts_across_obj_array = {bcasts_across_obj_array}")
gen("")
# {{{ unary operators
for dunder_name, op_str, op_cls in _UNARY_OP_AND_DUNDER:
if op_cls not in desired_op_classes:
continue
fname = f"_{cls.__name__.lower()}_{dunder_name}"
init_args = cls._deserialize_init_arrays_code("arg1", {
key_arg1: _format_unary_op_str(op_str, expr_arg1)
for key_arg1, expr_arg1 in
cls._serialize_init_arrays_code("arg1").items()
})
gen(f"""
def {fname}(arg1):
return cls({init_args})
cls.__{dunder_name}__ = {fname}""")
gen("")
# }}}
# {{{ binary operators
for dunder_name, op_str, reversible, op_cls in _BINARY_OP_AND_DUNDER:
fname = f"_{cls.__name__.lower()}_{dunder_name}"
if op_cls not in desired_op_classes:
# Leaving equality comparison at the default supplied by
# dataclasses is dangerous: Comparison of dataclass fields
# might return an array of truth values, and the dataclasses
# implementation of __eq__ might consider that 'truthy' enough,
# yielding bogus equality results.
if op_cls == _OpClass.EQ_COMPARISON:
gen(f"def {fname}(arg1, arg2):")
with Indentation(gen):
gen("return NotImplemented")
gen(f"cls.__{dunder_name}__ = {fname}")
gen("")
continue
zip_init_args = cls._deserialize_init_arrays_code("arg1", {
same_key(key_arg1, key_arg2):
_format_binary_op_str(op_str, expr_arg1, expr_arg2)
for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip(
cls._serialize_init_arrays_code("arg1").items(),
cls._serialize_init_arrays_code("arg2").items(),
strict=True)
})
bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", {
key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
for key_arg1, expr_arg1 in
cls._serialize_init_arrays_code("arg1").items()
})
bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", {
key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2)
for key_arg2, expr_arg2 in
cls._serialize_init_arrays_code("arg2").items()
})
# {{{ "forward" binary operators
gen(f"def {fname}(arg1, arg2):")
with Indentation(gen):
gen("if arg2.__class__ is cls:")
with Indentation(gen):
if __debug__ and cls_has_array_context_attr:
gen("""
arg1_actx = arg1.array_context
arg2_actx = arg2.array_context
if arg1_actx is not arg2_actx:
msg = ("array contexts of both arguments "
"must match")
if arg1_actx is None:
raise ValueError(msg
+ ": left operand is frozen "
"(i.e. has no array context)")
elif arg2_actx is None:
raise ValueError(msg
+ ": right operand is frozen "
"(i.e. has no array context)")
else:
raise ValueError(msg)""")
gen(f"return cls({zip_init_args})")
if bcast_actx_array_type:
if __debug__:
bcast_actx_ary_types: tuple[str, ...] = (
"*_raise_if_actx_none("
"arg1.array_context).array_types",)
else:
bcast_actx_ary_types = (
"*arg1.array_context.array_types",)
else:
bcast_actx_ary_types = ()
gen(f"""
if {numpy_pred("arg2")}:
result = np.empty_like(arg2, dtype=object)
for i in np.ndindex(arg2.shape):
result[i] = {op_str.format("arg1", "arg2[i]")}
return result
if {bool(container_type_names_bcast_across)}: # optimized away
if isinstance(arg2,
{tup_str(container_type_names_bcast_across
+ bcast_actx_ary_types)}):
if __debug__:
if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
warn("Broadcasting {cls} over array "
f"context array type {{type(arg2)}} is deprecated "
"and will no longer work in 2026. "
"Use arraycontext.Bcast* object wrappers for "
"roughly equivalent functionality. "
"See the discussion in "
"https://github.com/inducer/arraycontext/"
"pull/190. ",
DeprecationWarning, stacklevel=2)
return cls({bcast_init_args_arg1_is_outer})
return NotImplemented
""")
gen(f"cls.__{dunder_name}__ = {fname}")
gen("")
# }}}
# {{{ "reverse" binary operators
if reversible:
fname = f"_{cls.__name__.lower()}_r{dunder_name}"
if bcast_actx_array_type:
if __debug__:
bcast_actx_ary_types = (
"*_raise_if_actx_none("
"arg2.array_context).array_types",)
else:
bcast_actx_ary_types = (
"*arg2.array_context.array_types",)
else:
bcast_actx_ary_types = ()
gen(f"""
def {fname}(arg2, arg1):
# assert other.__cls__ is not cls
if {numpy_pred("arg1")}:
result = np.empty_like(arg1, dtype=object)
for i in np.ndindex(arg1.shape):
result[i] = {op_str.format("arg1[i]", "arg2")}
return result
if {bool(container_type_names_bcast_across)}: # optimized away
if isinstance(arg1,
{tup_str(container_type_names_bcast_across
+ bcast_actx_ary_types)}):
if __debug__:
if isinstance(arg1,
{tup_str(bcast_actx_ary_types)}):
warn("Broadcasting {cls} over array "
f"context array type {{type(arg1)}} "
"is deprecated "
"and will no longer work in 2026."
"Use arraycontext.Bcast* object "
"wrappers for roughly equivalent "
"functionality. "
"See the discussion in "
"https://github.com/inducer/arraycontext/"
"pull/190. ",
DeprecationWarning, stacklevel=2)
return cls({bcast_init_args_arg2_is_outer})
return NotImplemented
cls.__r{dunder_name}__ = {fname}""")
gen("")
# }}}
# }}}
# This will evaluate the module, which is all we need.
code = gen.get().rstrip()+"\n"
result_dict = {"_MODULE_SOURCE_CODE": code, "cls": cls}
exec(compile(code, f"<container arithmetic for {cls.__name__}>", "exec"),
result_dict)
return cls
# we're being called as @with_container_arithmetic(...), with parens
return wrap
# }}}
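# ------------------------------------------------------------------------
# Illustrative usage sketch, not part of the library. The state class below is
# hypothetical, and it assumes that dataclass_array_container accepts fields
# annotated with arraycontext.Array. As the docstring above notes,
# with_container_arithmetic nests *outside* dataclass_array_container, which
# supplies the (de)serialization class methods the generated operators need.
from arraycontext.container.dataclass import dataclass_array_container
from arraycontext.context import Array


@with_container_arithmetic(
        bcasts_across_obj_array=True,
        rel_comparison=False,
        eq_comparison=False)
@dataclass_array_container
@dataclass(frozen=True)
class _ExampleState:
    mass: Array
    momentum: Array

    # keep numpy from hijacking binary operations with object arrays
    __array_ufunc__ = None

# With this in place, expressions such as
#     2 * state_a + state_b        or        -state_a
# recurse into the dataclass fields component by component.
# ------------------------------------------------------------------------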
# {{{ Bcast object-ified broadcast rules
# Possible advantages of the "Bcast" broadcast-rule-as-object design:
#
# - If one rule does not fit the user's need, they can straightforwardly use
# another.
#
# - It's straightforward to find where certain broadcast rules are used.
#
# - The broadcast rule can contain more state. For example, it's now easy
# for the rule to know what array context should be used to determine
# actx array types.
#
# Possible downsides of the "Bcast" broadcast-rule-as-object design:
#
# - User code is a bit more wordy.
@dataclass(frozen=True)
class BcastUntilActxArray:
"""
An operator-overloading wrapper around an object (*broadcastee*) that should be
broadcast across array containers until the 'opposite' operand is one of the
:attr:`~arraycontext.ArrayContext.array_types`
of *actx* or a :class:`~numbers.Number`.
Suggested usage pattern::
bcast = functools.partial(BcastUntilActxArray, actx)
container + bcast(actx_array)
.. automethod:: __init__
"""
array_context: ArrayContext
broadcastee: ArrayOrContainer
_stop_types: tuple[type, ...] = field(init=False)
def __post_init__(self) -> None:
object.__setattr__(
self, "_stop_types", (*self.array_context.array_types, Number))
def _binary_op(self,
op: Callable[
[ArrayOrContainer, ArrayOrContainer],
ArrayOrContainer
],
right: ArrayOrContainer
) -> ArrayOrContainer:
try:
serialized = serialize_container(right)
except NotAnArrayContainerError:
return op(self.broadcastee, right)
return deserialize_container(right, [
(k, op(self.broadcastee, right_v)
if isinstance(right_v, self._stop_types) else
self._binary_op(op, right_v)
)
for k, right_v in serialized])
def _rev_binary_op(self,
op: Callable[
[ArrayOrContainer, ArrayOrContainer],
ArrayOrContainer
],
left: ArrayOrContainer
) -> ArrayOrContainer:
try:
serialized = serialize_container(left)
except NotAnArrayContainerError:
return op(left, self.broadcastee)
return deserialize_container(left, [
(k, op(left_v, self.broadcastee)
if isinstance(left_v, self._stop_types) else
self._rev_binary_op(op, left_v)
)
for k, left_v in serialized])
__add__ = partialmethod(_binary_op, operator.add)
__radd__ = partialmethod(_rev_binary_op, operator.add)
__sub__ = partialmethod(_binary_op, operator.sub)
__rsub__ = partialmethod(_rev_binary_op, operator.sub)
__mul__ = partialmethod(_binary_op, operator.mul)
__rmul__ = partialmethod(_rev_binary_op, operator.mul)
__truediv__ = partialmethod(_binary_op, operator.truediv)
__rtruediv__ = partialmethod(_rev_binary_op, operator.truediv)
__floordiv__ = partialmethod(_binary_op, operator.floordiv)
__rfloordiv__ = partialmethod(_rev_binary_op, operator.floordiv)
__mod__ = partialmethod(_binary_op, operator.mod)
__rmod__ = partialmethod(_rev_binary_op, operator.mod)
__pow__ = partialmethod(_binary_op, operator.pow)
__rpow__ = partialmethod(_rev_binary_op, operator.pow)
__lshift__ = partialmethod(_binary_op, operator.lshift)
__rlshift__ = partialmethod(_rev_binary_op, operator.lshift)
__rshift__ = partialmethod(_binary_op, operator.rshift)
__rrshift__ = partialmethod(_rev_binary_op, operator.rshift)
__and__ = partialmethod(_binary_op, operator.and_)
__rand__ = partialmethod(_rev_binary_op, operator.and_)
__or__ = partialmethod(_binary_op, operator.or_)
__ror__ = partialmethod(_rev_binary_op, operator.or_)
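# An illustrative sketch of the intended usage (editor-added; ``actx`` is
# assumed to be an existing ArrayContext and ``state`` an array container
# holding arrays of that context -- both names are hypothetical):
#
#     import functools
#     import numpy as np
#
#     bcast = functools.partial(BcastUntilActxArray, actx)
#
#     rho = actx.from_numpy(np.linspace(0.0, 1.0, 10))
#     # ``rho`` is broadcast across the container until an actx array
#     # (or a Number) is reached, at which point the binary operation applies:
#     scaled_state = state * bcast(rho)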
# }}}
# vim: foldmethod=marker
# mypy: disallow-untyped-defs
"""
.. currentmodule:: arraycontext
.. autofunction:: dataclass_array_container
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from collections.abc import Mapping, Sequence
from dataclasses import fields, is_dataclass
from typing import NamedTuple, Union, get_args, get_origin
from arraycontext.container import is_array_container_type
# {{{ dataclass containers
class _Field(NamedTuple):
"""Small lookalike for :class:`dataclasses.Field`."""
init: bool
name: str
type: type
def is_array_type(tp: type) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)
def dataclass_array_container(cls: type) -> type:
"""A class decorator that makes the class to which it is applied an
:class:`ArrayContainer` by registering appropriate implementations of
:func:`serialize_container` and :func:`deserialize_container`.
*cls* must be a :func:`~dataclasses.dataclass`.
Attributes that are not array containers are allowed. In order to decide
whether an attribute is an array container, the declared attribute type
is checked by the criteria from :func:`is_array_container_type`. This
includes some support for type annotations:
* a :class:`typing.Union` of array containers is considered an array container.
* other type annotations, e.g. :class:`typing.Optional`, are not considered
array containers, even if they wrap one.
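For illustration, a hypothetical container (``FluidState`` and its fields are
made-up names, not part of :mod:`arraycontext`) could be declared as::

    from dataclasses import dataclass
    from arraycontext import Array, dataclass_array_container

    @dataclass_array_container
    @dataclass(frozen=True)
    class FluidState:
        mass: Array
        momentum: Array
        # non-array metadata is allowed and is copied over on deserialization
        dim: int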
.. note::
When type annotations are strings (e.g. because of
``from __future__ import annotations``),
this function relies on :func:`inspect.get_annotations`
(with ``eval_str=True``) to obtain type annotations. This
means that *cls* must live in a module that is importable.
"""
from types import GenericAlias, UnionType
assert is_dataclass(cls)
def is_array_field(f: _Field) -> bool:
field_type = f.type
# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
#
# `Optional[ArrayContainer]` is not allowed, since `None` is not
# handled by `with_container_arithmetic`, which is the common case
# for current container usage. Other type annotations, e.g.
# `Tuple[Container, Container]`, are also not allowed, as they do not
# work with `with_container_arithmetic`.
#
# This is not set in stone, but mostly driven by current usage!
origin = get_origin(field_type)
# NOTE: `UnionType` is returned when using `Type1 | Type2`
if origin in (Union, UnionType):
if all(is_array_type(arg) for arg in get_args(field_type)):
return True
else:
raise TypeError(
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")
# NOTE: this should never happen due to using `inspect.get_annotations`
assert not isinstance(field_type, str)
if __debug__:
if not f.init:
raise ValueError(
f"Field with 'init=False' not allowed: '{f.name}'")
# NOTE:
# * `GenericAlias` catches typed `list`, `tuple`, etc.
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
# * `_SpecialForm` catches `Any`, `Literal`, etc.
from typing import ( # type: ignore[attr-defined]
_BaseGenericAlias,
_SpecialForm,
)
if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"Typing annotation not supported on field '{f.name}': "
f"'{field_type!r}'")
if not isinstance(field_type, type):
raise TypeError(
f"Field '{f.name}' not an instance of 'type': "
f"'{field_type!r}'")
return is_array_type(field_type)
from pytools import partition
array_fields = _get_annotated_fields(cls)
array_fields, non_array_fields = partition(is_array_field, array_fields)
if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")
return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
def _get_annotated_fields(cls: type) -> Sequence[_Field]:
"""Get a list of fields in the class *cls* with evaluated types.
If any of the fields in *cls* have type annotations that are strings, e.g.
from using ``from __future__ import annotations``, this function evaluates
them using :func:`inspect.get_annotations`. Note that this requires the class
to live in a module that is importable.
:return: a list of fields.
"""
from inspect import get_annotations
result = []
cls_ann: Mapping[str, type] | None = None
for field in fields(cls):
field_type_or_str = field.type
if isinstance(field_type_or_str, str):
if cls_ann is None:
cls_ann = get_annotations(cls, eval_str=True)
field_type = cls_ann[field.name]
else:
field_type = field_type_or_str
result.append(_Field(init=field.init, name=field.name, type=field_type))
return result
def _inject_dataclass_serialization(
cls: type,
array_fields: Sequence[_Field],
non_array_fields: Sequence[_Field]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
This function modifies *cls* in place, so the returned value is the same
object with additional functionality.
:arg array_fields: fields of the given dataclass *cls* which are considered
array containers and should be serialized.
:arg non_array_fields: remaining fields of the dataclass *cls* which are
copied over from the template array in deserialization.
"""
assert is_dataclass(cls)
serialize_expr = ", ".join(
f"({f.name!r}, ary.{f.name})" for f in array_fields)
template_kwargs = ", ".join(
f"{f.name}=template.{f.name}" for f in non_array_fields)
lower_cls_name = cls.__name__.lower()
serialize_init_code = ", ".join(f"{f.name!r}: f'{{instance_name}}.{f.name}'"
for f in array_fields)
deserialize_init_code = ", ".join([
f"{f.name}={{args[{f.name!r}]}}" for f in array_fields
] + [
f"{f.name}={{template_instance_name}}.{f.name}"
for f in non_array_fields
])
from pytools.codegen import remove_common_indentation
serialize_code = remove_common_indentation(f"""
from typing import Any, Iterable, Tuple
from arraycontext import serialize_container, deserialize_container
@serialize_container.register(cls)
def _serialize_{lower_cls_name}(ary: cls) -> Iterable[Tuple[Any, Any]]:
return ({serialize_expr},)
@deserialize_container.register(cls)
def _deserialize_{lower_cls_name}(
template: cls, iterable: Iterable[Tuple[Any, Any]]) -> cls:
return cls(**dict(iterable), {template_kwargs})
# support for with_container_arithmetic
def _serialize_init_arrays_code_{lower_cls_name}(cls, instance_name):
return {{
{serialize_init_code}
}}
cls._serialize_init_arrays_code = classmethod(
_serialize_init_arrays_code_{lower_cls_name})
def _deserialize_init_arrays_code_{lower_cls_name}(
cls, template_instance_name, args):
return f"{deserialize_init_code}"
cls._deserialize_init_arrays_code = classmethod(
_deserialize_init_arrays_code_{lower_cls_name})
""")
exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": serialize_code}
exec(compile(serialize_code, f"<container serialization for {cls.__name__}>",
"exec"), exec_dict)
return cls
# }}}
# vim: foldmethod=marker
# mypy: disallow-untyped-defs
"""
.. currentmodule:: arraycontext
.. autofunction:: map_array_container
.. autofunction:: multimap_array_container
.. autofunction:: rec_map_array_container
.. autofunction:: rec_multimap_array_container
.. autofunction:: map_reduce_array_container
.. autofunction:: multimap_reduce_array_container
.. autofunction:: rec_map_reduce_array_container
.. autofunction:: rec_multimap_reduce_array_container
.. autofunction:: stringify_array_container_tree
Traversing decorators
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: mapped_over_array_containers
.. autofunction:: multimapped_over_array_containers
Freezing and thawing
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: freeze
.. autofunction:: thaw
Flattening and unflattening
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: flatten
.. autofunction:: unflatten
.. autofunction:: flat_size_and_dtype
Numpy conversion
~~~~~~~~~~~~~~~~
.. autofunction:: from_numpy
.. autofunction:: to_numpy
Algebraic operations
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: outer
"""
from __future__ import annotations
from arraycontext.container.arithmetic import NumpyObjectArray
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from collections.abc import Callable, Iterable
from functools import partial, singledispatch, update_wrapper
from typing import Any, cast
from warnings import warn
import numpy as np
from arraycontext.container import (
ArrayContainer,
NotAnArrayContainerError,
SerializationKey,
deserialize_container,
get_container_context_recursively_opt,
serialize_container,
)
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerT,
ArrayT,
ScalarLike,
)
# {{{ array container traversal helpers
def _map_array_container_impl(
f: Callable[[ArrayOrContainer], ArrayOrContainer],
ary: ArrayOrContainer, *,
leaf_cls: type | None = None,
recursive: bool = False) -> ArrayOrContainer:
"""Helper for :func:`rec_map_array_container`.
:param leaf_cls: class on which we call *f* directly. This is mostly
useful in the recursive setting, where it can stop the recursion on
specific container classes. By default, the recursion is stopped when
a non-:class:`ArrayContainer` class is encountered.
"""
def rec(ary_: ArrayOrContainer) -> ArrayOrContainer:
if type(ary_) is leaf_cls: # type(ary) is never None
return f(ary_)
try:
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return f(ary_)
else:
return deserialize_container(ary_, [
(key, frec(subary)) for key, subary in iterable
])
frec = rec if recursive else f
return rec(ary)
def _multimap_array_container_impl(
f: Callable[..., Any],
*args: Any,
reduce_func: (
Callable[[ArrayContainer, Iterable[tuple[Any, Any]]], Any] | None) = None,
leaf_cls: type | None = None,
recursive: bool = False) -> ArrayOrContainer:
"""Helper for :func:`rec_multimap_array_container`.
:param leaf_cls: class on which we call *f* directly. This is mostly
useful in the recursive setting, where it can stop the recursion on
specific container classes. By default, the recursion is stopped when
a non-:class:`ArrayContainer` class is encountered.
"""
# {{{ recursive traversal
def rec(*args_: Any) -> Any:
template_ary = args_[container_indices[0]]
if type(template_ary) is leaf_cls:
return f(*args_)
try:
iterable_template = serialize_container(template_ary)
except NotAnArrayContainerError:
return f(*args_)
assert all(
type(args_[i]) is type(template_ary) for i in container_indices[1:]
), f"expected type '{type(template_ary).__name__}'"
result = []
new_args = list(args_)
for subarys in zip(
iterable_template,
*[serialize_container(args_[i]) for i in container_indices[1:]],
strict=True
):
key = None
for i, (subkey, subary) in zip(container_indices, subarys, strict=True):
if key is None:
key = subkey
else:
assert key == subkey
new_args[i] = subary
result.append((key, frec(*new_args)))
return process_container(template_ary, result)
# }}}
# {{{ find all containers in the argument list
container_indices: list[int] = []
for i, arg in enumerate(args):
if type(arg) is leaf_cls:
continue
try:
# FIXME: this will serialize again once `rec` is called, which is
# not great, but it doesn't seem like there's a good way to avoid it
_ = serialize_container(arg)
except NotAnArrayContainerError:
pass
else:
container_indices.append(i)
# }}}
# {{{ #containers == 0 => call `f`
if not container_indices:
return f(*args)
# }}}
# {{{ #containers == 1 => call `map_array_container`
if len(container_indices) == 1 and reduce_func is None:
# NOTE: if we just have one ArrayContainer in args, passing it through
# _map_array_container_impl should be faster
def wrapper(ary: ArrayOrContainerT) -> ArrayOrContainerT:
new_args = list(args)
new_args[container_indices[0]] = ary
return f(*new_args)
update_wrapper(wrapper, f)
template_ary: ArrayContainer = args[container_indices[0]]
return _map_array_container_impl(
wrapper, template_ary,
leaf_cls=leaf_cls, recursive=recursive)
# }}}
# {{{ #containers > 1 => call `rec`
process_container = deserialize_container if reduce_func is None else reduce_func
frec = rec if recursive else f
# }}}
return rec(*args)
# }}}
# {{{ array container traversal
def stringify_array_container_tree(ary: ArrayOrContainer) -> str:
"""
:returns: a string for an ASCII tree representation of the array container,
similar to `asciitree <https://github.com/mbr/asciitree>`__.
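For a hypothetical two-field container, the result looks roughly like this
(the type names in parentheses depend on the actual container and array
types)::

    root (FluidState)
     +-- mass (Array)
     +-- momentum (Array)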
"""
def rec(lines: list[str], ary_: ArrayOrContainerT, level: int) -> None:
try:
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
pass
else:
for key, subary in iterable:
key = f"{key} ({type(subary).__name__})"
indent = "" if level == 0 else f" | {' ' * 4 * (level - 1)}"
lines.append(f"{indent} +-- {key}")
rec(lines, subary, level + 1)
lines = [f"root ({type(ary).__name__})"]
rec(lines, ary, 0)
return "\n".join(lines)
def map_array_container(
f: Callable[[Any], Any],
ary: ArrayOrContainer) -> ArrayOrContainer:
r"""Applies *f* to all components of an :class:`ArrayContainer`.
Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but
on arbitrary containers.
For a recursive version, see :func:`rec_map_array_container`.
:param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
or an instance of a base array type.
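A brief usage sketch (``state`` is assumed to be some array container); the
callable is applied once to each component::

    doubled = map_array_container(lambda subary: 2 * subary, state)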
"""
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError:
return f(ary)
else:
return deserialize_container(ary, [
(key, f(subary)) for key, subary in iterable
])
def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
r"""Applies *f* to the components of multiple :class:`ArrayContainer`\ s.
Works similarly to :func:`~pytools.obj_array.obj_array_vectorize_n_args`,
but on arbitrary containers. The containers must all have the same type,
which will also be the return type.
For a recursive version, see :func:`rec_multimap_array_container`.
:param args: all :class:`ArrayContainer` arguments must be of the same
type and with the same structure (same number of components, etc.).
"""
return _multimap_array_container_impl(f, *args, recursive=False)
def rec_map_array_container(
f: Callable[[Any], Any],
ary: ArrayOrContainer,
leaf_class: type | None = None) -> ArrayOrContainer:
r"""Applies *f* recursively to an :class:`ArrayContainer`.
For a non-recursive version see :func:`map_array_container`.
:param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
or an instance of a base array type.
"""
return _map_array_container_impl(f, ary, leaf_cls=leaf_class, recursive=True)
def mapped_over_array_containers(
f: Callable[[ArrayOrContainer], ArrayOrContainer] | None = None,
leaf_class: type | None = None) -> (
Callable[[ArrayOrContainer], ArrayOrContainer]
| Callable[
[Callable[[Any], Any]],
Callable[[ArrayOrContainer], ArrayOrContainer]]):
"""Decorator around :func:`rec_map_array_container`."""
def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[
[ArrayOrContainer], ArrayOrContainer]:
wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class)
update_wrapper(wrapper, g)
return wrapper
if f is not None:
return decorator(f)
else:
return decorator
def rec_multimap_array_container(
f: Callable[..., Any],
*args: Any,
leaf_class: type | None = None) -> Any:
r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
For a non-recursive version see :func:`multimap_array_container`.
:param args: all :class:`ArrayContainer` arguments must be of the same
type and with the same structure (same number of components, etc.).
"""
return _multimap_array_container_impl(
f, *args, leaf_cls=leaf_class, recursive=True)
def multimapped_over_array_containers(
f: Callable[..., Any] | None = None,
leaf_class: type | None = None) -> (
Callable[..., Any]
| Callable[[Callable[..., Any]], Callable[..., Any]]):
"""Decorator around :func:`rec_multimap_array_container`."""
def decorator(g: Callable[..., Any]) -> Callable[..., Any]:
# can't use functools.partial, because its result is insufficiently
# function-y to be used as a method definition.
def wrapper(*args: Any) -> Any:
return rec_multimap_array_container(g, *args, leaf_class=leaf_class)
update_wrapper(wrapper, g)
return wrapper
if f is not None:
return decorator(f)
else:
return decorator
# }}}
# {{{ keyed array container traversal
def keyed_map_array_container(
f: Callable[
[SerializationKey, ArrayOrContainer],
ArrayOrContainer],
ary: ArrayOrContainer) -> ArrayOrContainer:
r"""Applies *f* to all components of an :class:`ArrayContainer`.
Works similarly to :func:`map_array_container`, but *f* also takes an
identifier of the array in the container *ary*.
For a recursive version, see :func:`rec_keyed_map_array_container`.
:param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
or an instance of a base array type.
"""
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError as err:
raise ValueError(
f"Non-array container type has no key: {type(ary).__name__}") from err
else:
return deserialize_container(ary, [
(key, f(key, subary)) for key, subary in iterable
])
def rec_keyed_map_array_container(
f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
ary: ArrayOrContainer) -> ArrayOrContainer:
"""
Works similarly to :func:`rec_map_array_container`, except that *f* also
takes in a traversal path to the leaf array. The traversal path argument is
passed in as a tuple of identifiers of the arrays traversed before reaching
the current array.
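A brief sketch (``state`` is an assumed array container); the callable
receives the full key path of each leaf::

    def print_path_and_keep(keys, ary):
        print("/".join(str(k) for k in keys))
        return ary

    result = rec_keyed_map_array_container(print_path_and_keep, state)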
"""
def rec(keys: tuple[SerializationKey, ...],
ary_: ArrayOrContainerT) -> ArrayOrContainerT:
try:
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_)))
else:
return deserialize_container(ary_, [
(key, rec((*keys, key), subary)) for key, subary in iterable
])
return rec((), ary)
# }}}
# {{{ array container reductions
def map_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[[Any], Any],
ary: ArrayOrContainerT) -> Array:
"""Perform a map-reduce over array containers.
:param reduce_func: callable used to reduce over the components of *ary*
if *ary* is an :class:`~arraycontext.ArrayContainer`. The callable
should be associative, as for :func:`rec_map_reduce_array_container`.
:param map_func: callable used to map a single array of type
:class:`arraycontext.ArrayContext.array_types`. Returns an array of the
same type or a scalar.
"""
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError:
return map_func(ary)
else:
return reduce_func([
map_func(subary) for _, subary in iterable
])
def multimap_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[..., Any],
*args: Any) -> ArrayOrContainer:
r"""Perform a map-reduce over multiple array containers.
:param reduce_func: callable used to reduce over the components of any
:class:`~arraycontext.ArrayContainer`\ s in *\*args*. The callable
should be associative, as for :func:`rec_map_reduce_array_container`.
:param map_func: callable used to map a single array of type
:class:`arraycontext.ArrayContext.array_types`. Returns an array of the
same type or a scalar.
"""
# NOTE: this wrapper matches the signature of `deserialize_container`
# to make plugging into `_multimap_array_container_impl` easier
def _reduce_wrapper(
ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]]
) -> Array:
return reduce_func([subary for _, subary in iterable])
return _multimap_array_container_impl(
map_func, *args,
reduce_func=_reduce_wrapper, leaf_cls=None, recursive=False)
def rec_map_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[[Any], Any],
ary: ArrayOrContainer,
leaf_class: type | None = None) -> ArrayOrContainer:
"""Perform a map-reduce over array containers recursively.
:param reduce_func: callable used to reduce over the components of *ary*
(and those of its sub-containers) if *ary* is a
:class:`~arraycontext.ArrayContainer`. Must be associative.
:param map_func: callable used to map a single array of type
:class:`arraycontext.ArrayContext.array_types`. Returns an array of the
same type or a scalar.
.. note::
The traversal order is unspecified. *reduce_func* must be associative in
order to guarantee a sensible result. This is because *reduce_func* may be
called on subsets of the component arrays, and then again (potentially
multiple times) on the results. As an example, consider a container made up
of two sub-containers, *subcontainer0* and *subcontainer1*, that each
contain two component arrays, *array0* and *array1*. The same result must be
computed whether traversing recursively::
reduce_func([
reduce_func([
map_func(subcontainer0.array0),
map_func(subcontainer0.array1)]),
reduce_func([
map_func(subcontainer1.array0),
map_func(subcontainer1.array1)])])
reducing all of the arrays at once::
reduce_func([
map_func(subcontainer0.array0),
map_func(subcontainer0.array1),
map_func(subcontainer1.array0),
map_func(subcontainer1.array1)])
or any other such traversal.
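As a simple (editor-added) illustration, the total number of leaf array
entries in an assumed container ``state`` can be computed by::

    nentries = rec_map_reduce_array_container(
        sum, lambda subary: subary.size, state)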
"""
def rec(ary_: ArrayOrContainerT) -> ArrayOrContainerT:
if type(ary_) is leaf_class:
return map_func(ary_)
else:
try:
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return map_func(ary_)
else:
return reduce_func([
rec(subary) for _, subary in iterable
])
return rec(ary)
def rec_multimap_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[..., Any],
*args: Any,
leaf_class: type | None = None) -> ArrayOrContainer:
r"""Perform a map-reduce over multiple array containers recursively.
:param reduce_func: callable used to reduce over the components of any
:class:`~arraycontext.ArrayContainer`\ s in *\*args* (and those of their
sub-containers). Must be associative.
:param map_func: callable used to map a single array of type
:class:`arraycontext.ArrayContext.array_types`. Returns an array of the
same type or a scalar.
.. note::
The traversal order is unspecified. *reduce_func* must be associative in
order to guarantee a sensible result. See
:func:`rec_map_reduce_array_container` for additional details.
"""
# NOTE: this wrapper matches the signature of `deserialize_container`
# to make plugging into `_multimap_array_container_impl` easier
def _reduce_wrapper(
ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]]) -> Any:
return reduce_func([subary for _, subary in iterable])
return _multimap_array_container_impl(
map_func, *args,
reduce_func=_reduce_wrapper, leaf_cls=leaf_class, recursive=True)
# }}}
# {{{ freeze/thaw
def freeze(
ary: ArrayOrContainerT,
actx: ArrayContext | None = None) -> ArrayOrContainerT:
r"""Freezes recursively by going through all components of the
:class:`ArrayContainer` *ary*.
:param ary: a :meth:`~ArrayContext.thaw`\ ed :class:`ArrayContainer`.
Array container types may use :func:`functools.singledispatch` ``.register`` to
register additional implementations.
See :meth:`ArrayContext.freeze`.
"""
if actx is None:
warn("Calling freeze(ary) without specifying actx is deprecated, explicitly"
" call actx.freeze(ary) instead. This will stop working in 2023.",
DeprecationWarning, stacklevel=2)
actx = get_container_context_recursively_opt(ary)
else:
warn("Calling freeze(ary, actx) is deprecated, call actx.freeze(ary)"
" instead. This will stop working in 2023.",
DeprecationWarning, stacklevel=2)
if __debug__:
rec_actx = get_container_context_recursively_opt(ary)
if (rec_actx is not None) and (rec_actx is not actx):
raise ValueError("Supplied array context does not agree with"
" the one obtained by traversing 'ary'.")
if actx is None:
raise TypeError(
f"cannot freeze arrays of type {type(ary).__name__} "
"when actx is not supplied. Try calling actx.freeze "
"directly or supplying an array context")
return actx.freeze(ary)
def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
r"""Thaws recursively by going through all components of the
:class:`ArrayContainer` *ary*.
:param ary: a :meth:`~ArrayContext.freeze`\ ed :class:`ArrayContainer`.
Array container types may use :func:`functools.singledispatch` ``.register``
to register additional implementations.
See :meth:`ArrayContext.thaw`.
Serves as the registration point (using :func:`~functools.singledispatch`
``.register``) to register additional implementations for :func:`thaw`.
.. note::
This function has the reverse argument order from the original function
in :mod:`meshmode`. This was necessary because
:func:`~functools.singledispatch` only dispatches on the first argument.
"""
warn("Calling thaw(ary, actx) is deprecated, call actx.thaw(ary) instead."
" This will stop working in 2023.",
DeprecationWarning, stacklevel=2)
if __debug__:
rec_actx = get_container_context_recursively_opt(ary)
if rec_actx is not None:
raise ValueError("cannot thaw a container that already has an array"
" context.")
return actx.thaw(ary)
# }}}
# {{{ with_array_context
@singledispatch
def with_array_context(ary: ArrayOrContainerT,
actx: ArrayContext | None) -> ArrayOrContainerT:
"""
Recursively associates *actx* with all the components of *ary*.
Array container types may use :func:`functools.singledispatch` ``.register``
to register container-specific implementations. See `this issue
<https://github.com/inducer/arraycontext/issues/162>`__ for discussion of
the future of this functionality.
"""
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError:
return ary
else:
return deserialize_container(ary, [(key, with_array_context(subary, actx))
for key, subary in iterable])
# }}}
# {{{ flatten / unflatten
def flatten(
ary: ArrayOrContainer, actx: ArrayContext, *,
leaf_class: type | None = None,
) -> Any:
"""Convert all arrays in the :class:`~arraycontext.ArrayContainer`
into a single flat array of a type in :attr:`arraycontext.ArrayContext.array_types`.
The operation requires :attr:`arraycontext.ArrayContext.np` to have
``ravel`` and ``concatenate`` methods implemented. The order in which the
individual leaf arrays appear in the final array is dependent on the order
given by :func:`~arraycontext.serialize_container`.
If *leaf_class* is given, then :func:`unflatten` will not be able to recover
the original *ary*.
:arg leaf_class: an :class:`~arraycontext.ArrayContainer` class on which
the recursion is stopped (subclasses are not considered). If given, only
the entries of this type are flattened and the rest of the tree
structure is left as is. By default, the recursion is stopped when
a non-:class:`~arraycontext.ArrayContainer` is found, which results in
the whole input container *ary* being flattened.
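A typical round trip through :func:`flatten` and :func:`unflatten` looks like
this sketch (``actx`` and the container ``state`` are assumed to exist)::

    flat = flatten(state, actx)             # single one-dimensional actx array
    state_2 = unflatten(state, flat, actx)  # same structure as ``state``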
"""
common_dtype = None
def _flatten(subary: ArrayOrContainer) -> list[Array]:
nonlocal common_dtype
try:
iterable = serialize_container(subary)
except NotAnArrayContainerError:
subary_c = cast(Array, subary)
if common_dtype is None:
common_dtype = subary_c.dtype
if subary_c.dtype != common_dtype:
raise ValueError("arrays in container have different dtypes: "
f"got {subary_c.dtype}, expected {common_dtype}") from None
try:
flat_subary = actx.np.ravel(subary_c, order="C")
except ValueError as exc:
# NOTE: we can't do much if the array context fails to ravel,
# since it is the one responsible for the actual memory layout
if hasattr(subary_c, "strides"):
strides_msg = f" and strides {subary_c.strides}"
else:
strides_msg = ""
raise NotImplementedError(
f"'{type(actx).__name__}.np.ravel' failed to reshape "
f"an array with shape {subary_c.shape}{strides_msg}. "
"This functionality needs to be implemented by the "
"array context.") from exc
result = [flat_subary]
else:
result = []
for _, isubary in iterable:
result.extend(_flatten(isubary))
return result
def _flatten_without_leaf_class(subary: ArrayOrContainer) -> Any:
result = _flatten(subary)
if len(result) == 1:
return result[0]
else:
return actx.np.concatenate(result)
def _flatten_with_leaf_class(subary: ArrayOrContainer) -> Any:
if type(subary) is leaf_class:
return _flatten_without_leaf_class(subary)
try:
iterable = serialize_container(subary)
except NotAnArrayContainerError:
return subary
else:
return deserialize_container(subary, [
(key, _flatten_with_leaf_class(isubary))
for key, isubary in iterable
])
if leaf_class is None:
return _flatten_without_leaf_class(ary)
else:
return _flatten_with_leaf_class(ary)
def unflatten(
template: ArrayOrContainerT, ary: Array,
actx: ArrayContext, *,
strict: bool = True) -> ArrayOrContainerT:
"""Unflatten an array *ary* produced by :func:`flatten` back into an
:class:`~arraycontext.ArrayContainer`.
The order and sizes of each slice into *ary* are determined by the
array container *template*.
:arg ary: a flat one-dimensional array with a size that matches the
number of entries in *template*.
:arg strict: if *True* additional :class:`~numpy.dtype` and stride
checking is performed on the unflattened array. Otherwise, these
checks are skipped.
"""
# NOTE: https://github.com/python/mypy/issues/7057
offset = 0
common_dtype = None
def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
nonlocal offset, common_dtype
try:
iterable = serialize_container(template_subary)
except NotAnArrayContainerError:
template_subary_c = cast(Array, template_subary)
# {{{ validate subary
if (offset + template_subary_c.size) > ary.size:
raise ValueError("'template' and 'ary' sizes do not match: "
"'template' is too large") from None
if strict:
if template_subary_c.dtype != ary.dtype:
raise ValueError("'template' dtype does not match 'ary': "
f"got {template_subary_c.dtype}, expected {ary.dtype}"
) from None
else:
# NOTE: still require that *template* has a uniform dtype
if common_dtype is None:
common_dtype = template_subary_c.dtype
else:
if common_dtype != template_subary_c.dtype:
raise ValueError("arrays in 'template' have different "
f"dtypes: got {template_subary_c.dtype}, but "
f"expected {common_dtype}.") from None
# }}}
# {{{ reshape
flat_subary = ary[offset:offset + template_subary_c.size]
try:
subary = actx.np.reshape(flat_subary,
template_subary_c.shape, order="C")
except ValueError as exc:
# NOTE: we can't do much if the array context fails to reshape,
# since it is the one responsible for the actual memory layout
raise NotImplementedError(
f"'{type(actx).__name__}.np.reshape' failed to reshape "
f"the flat array into shape {template_subary_c.shape}. "
"This functionality needs to be implemented by the "
"array context.") from exc
# }}}
# {{{ check strides
if strict and hasattr(template_subary_c, "strides"): # noqa: SIM102
# Checking strides for 0 sized arrays is ill-defined
# since they cannot be indexed
if (
# Mypy has a point: nobody promised a .strides attribute.
template_subary_c.strides != subary.strides
and template_subary_c.size != 0
):
raise ValueError(
# Mypy has a point: nobody promised a .strides attribute.
f"strides do not match template: got {subary.strides}, "
f"expected {template_subary_c.strides}") from None
# }}}
offset += template_subary_c.size
return subary
else:
return deserialize_container(template_subary, [
(key, _unflatten(isubary)) for key, isubary in iterable
])
if not isinstance(ary, actx.array_types):
raise TypeError("'ary' does not have a type supported by the provided "
f"array context: got '{type(ary).__name__}', expected one of "
f"{actx.array_types}")
if len(ary.shape) != 1:
raise ValueError(
"only one dimensional arrays can be unflattened: "
f"'ary' has shape {ary.shape}")
result = _unflatten(template)
if offset != ary.size:
raise ValueError("'template' and 'ary' sizes do not match: "
"'ary' is too large")
return cast(ArrayOrContainerT, result)
def flat_size_and_dtype(
ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]:
"""
:returns: a tuple ``(size, dtype)`` that would be the length and
:class:`numpy.dtype` of the one-dimensional array returned by
:func:`flatten`.
"""
common_dtype = None
def _flat_size(subary: ArrayOrContainer) -> int:
nonlocal common_dtype
try:
iterable = serialize_container(subary)
except NotAnArrayContainerError:
subary_c = cast(Array, subary)
if common_dtype is None:
common_dtype = subary_c.dtype
if subary_c.dtype != common_dtype:
raise ValueError("arrays in container have different dtypes: "
f"got {subary_c.dtype}, expected {common_dtype}") from None
return subary_c.size
else:
return sum(_flat_size(isubary) for _, isubary in iterable)
size = _flat_size(ary)
return size, common_dtype
# }}}
# {{{ numpy conversion
def from_numpy(
ary: np.ndarray | ScalarLike,
actx: ArrayContext) -> ArrayOrContainerOrScalar:
"""Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer`
to the base array type of :class:`~arraycontext.ArrayContext`.
The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`.
"""
warn("Calling from_numpy(ary, actx) is deprecated, call actx.from_numpy(ary)"
" instead. This will stop working in 2023.",
DeprecationWarning, stacklevel=2)
return actx.from_numpy(ary)
def to_numpy(ary: ArrayOrContainer, actx: ArrayContext) -> ArrayOrContainer:
"""Convert all arrays in the :class:`~arraycontext.ArrayContainer` to
:mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*.
The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`.
"""
warn("Calling to_numpy(ary, actx) is deprecated, call actx.to_numpy(ary)"
" instead. This will stop working in 2023.",
DeprecationWarning, stacklevel=2)
return actx.to_numpy(ary)
# }}}
# {{{ algebraic operations
def outer(a: Any, b: Any) -> Any:
"""
Compute the outer product of *a* and *b* while allowing either of them
to be an :class:`ArrayContainer`.
Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional
object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer`
always returns a matrix). Here the definition of "scalar" includes
all non-array-container types and any scalar-like array container types.
If *a* and *b* are both array containers, the result will have the same type
as *a*. If both are array containers and neither is an object array, they must
have the same type.
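For instance (a sketch: ``vel`` is assumed to be a :mod:`numpy` object array
whose entries are arrays or array containers)::

    vv = outer(vel, vel)   # object array of shape (vel.size, vel.size),
                           # entries are componentwise products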
"""
def treat_as_scalar(x: Any) -> bool:
try:
serialize_container(x)
except NotAnArrayContainerError:
return True
else:
return (
not isinstance(x, np.ndarray)
# This condition is whether "ndarrays should broadcast inside x".
and NumpyObjectArray not in x.__class__._outer_bcast_types)
a_is_ndarray = isinstance(a, np.ndarray)
b_is_ndarray = isinstance(b, np.ndarray)
if a_is_ndarray and a.dtype != object:
raise TypeError("passing a non-object numpy array is not allowed")
if b_is_ndarray and b.dtype != object:
raise TypeError("passing a non-object numpy array is not allowed")
if treat_as_scalar(a) or treat_as_scalar(b):
return a*b
elif a_is_ndarray and b_is_ndarray:
return np.outer(a, b)
elif a_is_ndarray or b_is_ndarray:
return map_array_container(lambda x: outer(x, b), a)
else:
if type(a) is not type(b):
raise TypeError(
"both arguments must have the same type if they are both "
"non-object-array array containers.")
return multimap_array_container(lambda x, y: outer(x, y), a, b)
# }}}
# vim: foldmethod=marker
# mypy: disallow-untyped-defs
"""
.. _freeze-thaw:
Freezing and thawing
--------------------
One of the central concepts introduced by the array context formalism is
the notion of :meth:`~arraycontext.ArrayContext.freeze` and
:meth:`~arraycontext.ArrayContext.thaw`. Each array handled by the array context
is either "thawed" or "frozen". Unlike the real-world concept of freezing and
thawing, these operations leave the original array alone; instead, a semantically
separate array in the desired state is returned.
* "Thawed" arrays are associated with an array context. They use that context
to carry out operations (arithmetic, function calls).
* "Frozen" arrays are static data. They are not associated with an array context,
and no operations can be performed on them.
Freezing and thawing may be used to move arrays from one array context to another,
as long as both array contexts use identical in-memory data representation.
Otherwise, a common format must be agreed upon, for example using
:mod:`numpy` through :meth:`~arraycontext.ArrayContext.to_numpy` and
:meth:`~arraycontext.ArrayContext.from_numpy`.
.. _freeze-thaw-guidelines:
Usage guidelines
^^^^^^^^^^^^^^^^
Here are some rules of thumb to use when dealing with thawing and freezing:
- Any array that is stored for a long time needs to be frozen.
"Memoized" data (cf. :func:`pytools.memoize` and friends) is a good example
of long-lived data that should be frozen.
- Within a function, if the user did not supply an array context,
then any data returned to the user should be frozen.
- Note that array contexts need not necessarily be passed as a separate
argument. Passing thawed data as an argument to a function suffices
to supply an array context. The array context can be extracted from
a thawed argument using, e.g., :func:`~arraycontext.get_container_context_opt`
or :func:`~arraycontext.get_container_context_recursively`.
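In code, the guidelines above translate into a pattern roughly like the
following sketch (``actx`` is an assumed array context and ``make_field`` a
hypothetical helper returning thawed data)::

    cached_field = actx.freeze(make_field(actx))   # safe to store long-term

    def use_field(actx):
        field = actx.thaw(cached_field)            # thaw before arithmetic
        return 2 * field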
What does this mean concretely?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Freezing and thawing are abstract names for concrete operations. It may be helpful
to understand what these operations mean in the concrete case of various
actual array contexts:
- Each :class:`~arraycontext.PyOpenCLArrayContext` is associated with a
:class:`pyopencl.CommandQueue`. In order to operate on array data,
such a command queue is necessary; it is the main means of synchronization
between the host program and the compute device. "Thawing" here
means associating an array with a command queue, and "freezing" means
ensuring that the array data is fully computed in memory and
decoupling the array from the command queue. It is not valid to "mix"
arrays associated with multiple queues within an operation: if it were allowed,
a dependent operation might begin computing before an input is fully
available. (Since bugs of this nature would be very difficult to
find, :class:`pyopencl.array.Array` and
:class:`~meshmode.dof_array.DOFArray` will not allow them.)
- For the lazily-evaluating array context based on :mod:`pytato`,
"thawing" corresponds to the creation of a symbolic "handle"
(specifically, a :class:`pytato.array.DataWrapper`) representing
the array that can then be used in computation, and "freezing"
corresponds to triggering (code generation and) evaluation of
an array expression that has been built up by the user
(using, e.g. :func:`pytato.generate_loopy`).
.. currentmodule:: arraycontext
The :class:`ArrayContext` Interface
-----------------------------------
.. autoclass:: ArrayContext
.. autofunction:: tag_axes
Types and Type Variables for Arrays and Containers
--------------------------------------------------
.. autoclass:: Array
.. autodata:: ArrayT
A type variable with a bound of :class:`Array`.
.. autodata:: ScalarLike
A type annotation for scalar types commonly usable with arrays.
See also :class:`ArrayContainer` and :class:`ArrayOrContainerT`.
.. autodata:: ArrayOrContainer
.. autodata:: ArrayOrContainerT
A type variable with a bound of :class:`ArrayOrContainer`.
.. autodata:: ArrayOrArithContainer
.. autodata:: ArrayOrArithContainerT
A type variable with a bound of :class:`ArrayOrArithContainer`.
.. autodata:: ArrayOrArithContainerOrScalar
.. autodata:: ArrayOrArithContainerOrScalarT
A type variable with a bound of :class:`ArrayOrArithContainerOrScalar`.
.. autodata:: ArrayOrContainerOrScalar
.. autodata:: ArrayOrContainerOrScalarT
A type variable with a bound of :class:`ArrayOrContainerOrScalar`.
.. currentmodule:: arraycontext.context
Canonical locations for type annotations
----------------------------------------
.. class:: ArrayT
:canonical: arraycontext.ArrayT
.. class:: ArrayOrContainerT
:canonical: arraycontext.ArrayOrContainerT
.. class:: ArrayOrContainerOrScalarT
:canonical: arraycontext.ArrayOrContainerOrScalarT
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload
from warnings import warn
import numpy as np
from typing_extensions import Self
from pytools import memoize_method
from pytools.tag import ToTagSetConvertible
if TYPE_CHECKING:
import loopy
from arraycontext.container import ArithArrayContainer, ArrayContainer
# {{{ typing
class Array(Protocol):
"""A :class:`~typing.Protocol` for the array type supported by
:class:`ArrayContext`.
This is meant to aid in typing annotations. For an explicit list of
supported types see :attr:`ArrayContext.array_types`.
.. attribute:: shape
.. attribute:: size
.. attribute:: dtype
.. attribute:: __getitem__
In addition, arrays are expected to support basic arithmetic.
"""
@property
def shape(self) -> tuple[int, ...]:
...
@property
def size(self) -> int:
...
@property
def dtype(self) -> np.dtype[Any]:
...
# Covering all the possible index variations is hard and (kind of) futile.
# If you'd like to see how, try changing the Any to
# AxisIndex = slice | int | "Array"
# Index = AxisIndex | tuple[AxisIndex]
def __getitem__(self, index: Any) -> Array:
...
# some basic arithmetic that's supposed to work
def __neg__(self) -> Self: ...
def __abs__(self) -> Self: ...
def __add__(self, other: Self | ScalarLike) -> Self: ...
def __radd__(self, other: Self | ScalarLike) -> Self: ...
def __sub__(self, other: Self | ScalarLike) -> Self: ...
def __rsub__(self, other: Self | ScalarLike) -> Self: ...
def __mul__(self, other: Self | ScalarLike) -> Self: ...
def __rmul__(self, other: Self | ScalarLike) -> Self: ...
def __truediv__(self, other: Self | ScalarLike) -> Self: ...
def __rtruediv__(self, other: Self | ScalarLike) -> Self: ...
ScalarLike: TypeAlias = int | float | complex | np.generic
# deprecated, use ScalarLike instead
Scalar = ScalarLike
ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike)
# NOTE: I'm kind of not sure about the *Tc versions of these type variables.
# mypy seems better at understanding arithmetic performed on the *Tc versions
# than the *T versions, whereas pyright doesn't seem to care.
#
# This issue seems to be part of it:
# https://github.com/python/mypy/issues/18203
# but there is likely other stuff lurking.
#
# For now, they're purposefully not in the main arraycontext.* name space.
ArrayT = TypeVar("ArrayT", bound=Array)
ArrayOrScalar: TypeAlias = "Array | ScalarLike"
ArrayOrContainer: TypeAlias = "Array | ArrayContainer"
ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer"
ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer)
ArrayOrContainerTc = TypeVar("ArrayOrContainerTc",
Array, "ArrayContainer", "ArithArrayContainer")
ArrayOrArithContainerT = TypeVar("ArrayOrArithContainerT", bound=ArrayOrArithContainer)
ArrayOrArithContainerTc = TypeVar("ArrayOrArithContainerTc",
Array, "ArithArrayContainer")
ArrayOrContainerOrScalar: TypeAlias = "Array | ArrayContainer | ScalarLike"
ArrayOrArithContainerOrScalar: TypeAlias = "Array | ArithArrayContainer | ScalarLike"
ArrayOrContainerOrScalarT = TypeVar(
"ArrayOrContainerOrScalarT",
bound=ArrayOrContainerOrScalar)
ArrayOrArithContainerOrScalarT = TypeVar(
"ArrayOrArithContainerOrScalarT",
bound=ArrayOrArithContainerOrScalar)
ArrayOrContainerOrScalarTc = TypeVar(
"ArrayOrContainerOrScalarTc",
ScalarLike, Array, "ArrayContainer", "ArithArrayContainer")
ArrayOrArithContainerOrScalarTc = TypeVar(
"ArrayOrArithContainerOrScalarTc",
ScalarLike, Array, "ArithArrayContainer")
ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike")
NumpyOrContainerOrScalar = Union[np.ndarray, "ArrayContainer", ScalarLike]
# }}}
# {{{ ArrayContext
class ArrayContext(ABC):
r"""
:canonical: arraycontext.ArrayContext
An interface that allows software implementing a numerical algorithm
(such as :class:`~meshmode.discretization.Discretization`) to create and interact
with arrays without knowing their types.
.. versionadded:: 2020.2
.. automethod:: from_numpy
.. automethod:: to_numpy
.. automethod:: call_loopy
.. automethod:: einsum
.. attribute:: np
Provides access to a namespace that serves as a work-alike to
:mod:`numpy`. The actual level of functionality provided is up to the
individual array context implementation, however the functions and
objects available under this namespace must not behave differently
from :mod:`numpy`.
As a baseline, special functions available through :mod:`loopy`
(e.g. ``sin``, ``exp``) are accessible through this interface.
A full list of implemented functionality is given in
:ref:`numpy-coverage`.
Callables accessible through this namespace vectorize over object
arrays, including :class:`arraycontext.ArrayContainer`\ s.
.. attribute:: array_types
A :class:`tuple` of types that are the valid array classes the
context can operate on. However, it is not necessary that *all* the
:class:`ArrayContext`\ 's operations are legal for the types in
*array_types*. Note that this tuple is *only* intended for use
with :func:`isinstance`. Other uses are not allowed. This allows
for 'types' with overridden :meth:`type.__instancecheck__`.
.. automethod:: freeze
.. automethod:: thaw
.. automethod:: freeze_thaw
.. automethod:: tag
.. automethod:: tag_axis
.. automethod:: compile
"""
array_types: tuple[type, ...] = ()
def __init__(self) -> None:
self.np = self._get_fake_numpy_namespace()
@abstractmethod
def _get_fake_numpy_namespace(self) -> Any:
...
def __hash__(self) -> int:
raise TypeError(f"unhashable type: '{type(self).__name__}'")
def zeros(self,
shape: int | tuple[int, ...],
dtype: np.dtype[Any]) -> Array:
warn(f"{type(self).__name__}.zeros is deprecated and will stop "
"working in 2025. Use actx.np.zeros instead.",
DeprecationWarning, stacklevel=2)
return self.np.zeros(shape, dtype)
@overload
def from_numpy(self, array: np.ndarray) -> Array:
...
@overload
def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...
@abstractmethod
def from_numpy(self,
array: NumpyOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
r"""
:returns: the :class:`numpy.ndarray` *array* converted to the
array context's array type. The returned array will be
:meth:`thaw`\ ed. When working with array containers each leaf
must be an :class:`~numpy.ndarray` or scalar, which is then converted
to the context's array type leaving the container structure
intact.
"""
@overload
def to_numpy(self, array: Array) -> np.ndarray:
...
@overload
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...
@abstractmethod
def to_numpy(self,
array: ArrayOrContainerOrScalar
) -> NumpyOrContainerOrScalar:
r"""
:returns: an :class:`numpy.ndarray` for each array recognized by the
context. The input *array* must be :meth:`thaw`\ ed.
When working with array containers each leaf must be one of
the context's array types or a scalar, which is then converted to
an :class:`~numpy.ndarray` leaving the container structure intact.
"""
@abstractmethod
def call_loopy(self,
t_unit: loopy.TranslationUnit,
**kwargs: Any) -> dict[str, Array]:
"""Execute the :mod:`loopy` program *program* on the arguments
*kwargs*.
*program* is a :class:`loopy.LoopKernel` or :class:`loopy.TranslationUnit`.
It is expected to not yet be transformed for execution speed.
It must have :attr:`loopy.Options.return_dict` set.
:return: a :class:`dict` of outputs from the program, each an
array understood by the context.
"""
@abstractmethod
def freeze(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
"""Return a version of the context-defined array *array* that is
'frozen', i.e. suitable for long-term storage and reuse. Frozen arrays
do not support arithmetic. For example, in the context of
:class:`~pyopencl.array.Array`, this might mean stripping the array
of an associated command queue, whereas in a lazily-evaluated context,
it might mean that the array is evaluated and stored.
Freezing makes the array independent of this :class:`ArrayContext`;
it is permitted to :meth:`thaw` it in a different one, as long as that
context understands the array format.
"""
@abstractmethod
def thaw(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
"""Take a 'frozen' array and return a new array representing the data in
*array* that is able to perform arithmetic and other operations, using
the execution resources of this context. In the context of
:class:`~pyopencl.array.Array`, this might mean that the array is
equipped with a command queue, whereas in a lazily-evaluated context,
it might mean that the returned array is a symbol bound to
the data in *array*.
The returned array may not be used with other contexts while thawed.
"""
def freeze_thaw(
self, array: ArrayOrContainerOrScalarT
) -> ArrayOrContainerOrScalarT:
r"""Evaluate an input array or container to "frozen" data return a new
"thawed" array or container representing the evaluation result that is
ready for use. This is a shortcut for calling :meth:`freeze` and
:meth:`thaw`.
This method can be useful in array contexts backed by, e.g.
:mod:`pytato`, to force the evaluation of a built-up array expression
(and thereby avoid reevaluations for expressions built on the array).
"""
return self.thaw(self.freeze(array))
@abstractmethod
def tag(self,
tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
"""If the array type used by the array context is capable of capturing
metadata, return a version of *array* with the *tags* applied. *array*
itself is not modified. When working with array containers, the
tags are applied to each leaf of the container.
See :ref:`metadata` as well as application-specific metadata types.
.. versionadded:: 2021.2
"""
@abstractmethod
def tag_axis(self,
iaxis: int, tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
"""If the array type used by the array context is capable of capturing
metadata, return a version of *array* in which axis number *iaxis* has
the *tags* applied. *array* itself is not modified. When working with
array containers, the tags are applied to each leaf of the container.
See :ref:`metadata` as well as application-specific metadata types.
.. versionadded:: 2021.2
"""
@memoize_method
def _get_einsum_prg(self,
spec: str, arg_names: tuple[str, ...],
tagged: ToTagSetConvertible) -> loopy.TranslationUnit:
import loopy as lp
from loopy.version import MOST_RECENT_LANGUAGE_VERSION
from .loopy import _DEFAULT_LOOPY_OPTIONS
return lp.make_einsum(
spec,
arg_names,
options=_DEFAULT_LOOPY_OPTIONS,
lang_version=MOST_RECENT_LANGUAGE_VERSION,
tags=tagged,
default_order=lp.auto,
default_offset=lp.auto,
)
# This lives here rather than in .np because the interface does not
# agree with numpy's all that well. Why can't it, you ask?
# Well, optimizing generic einsum for OpenCL/GPU execution
# is actually difficult, even in eager mode, and so without added
# metadata describing what's happening, transform_loopy_program
# has a very difficult (hopeless?) job to do.
#
# Unfortunately, the existing metadata support (cf. .tag()) cannot
# help with eager mode execution [1], because, by definition, when the
# result is passed to .tag(), it is already computed.
# That's why einsum's interface here needs to be cluttered with
# metadata, and that's why it can't live under .np.
# [1] https://github.com/inducer/meshmode/issues/177
def einsum(self,
spec: str, *args: Array,
arg_names: tuple[str, ...] | None = None,
tagged: ToTagSetConvertible = ()) -> Array:
"""Computes the result of Einstein summation following the
convention in :func:`numpy.einsum`.
:arg spec: a string denoting the subscripts for
summation as a comma-separated list of subscript labels.
This follows the usual :func:`numpy.einsum` convention.
Note that the explicit indicator `->` for the precise output
form is required.
:arg args: a sequence of array-like operands, whose order matches
the subscript labels provided by *spec*.
:arg arg_names: an optional iterable of strings denoting
the names of the *args*. If *None*, default names will be
generated.
:arg tagged: an optional sequence of :class:`pytools.tag.Tag`
objects specifying the tags to be applied to the operation.
:return: the output of the einsum :mod:`loopy` program
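A short usage sketch (``actx`` stands for a concrete array context,
``a`` and ``b`` for arrays belonging to it; the names are illustrative
only)::

    c = actx.einsum("ij,jk->ik", a, b, arg_names=("a", "b"))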
"""
if arg_names is None:
arg_names = tuple(f"arg{i}" for i in range(len(args)))
prg = self._get_einsum_prg(spec, arg_names, tagged)
out_ary = self.call_loopy(
prg, **{arg_names[i]: arg for i, arg in enumerate(args)}
)["out"]
return self.tag(tagged, out_ary)
@abstractmethod
def clone(self) -> Self:
"""If possible, return a version of *self* that is semantically
equivalent (i.e. implements all array operations in the same way)
but is a separate object. May return *self* if that is not possible.
.. note::
The main objective of this semi-documented method is to help
flag errors more clearly when array contexts are mixed that
should not be. For example, at the time of this writing,
:class:`meshmode.discretization.Discretization` objects have a private
array context that is only to be used for setup-related tasks.
By using :meth:`clone` to make this a separate array context,
and by checking that arithmetic does not mix array contexts,
it becomes easier to detect and flag if unfrozen data attached to a
"setup-only" array context "leaks" into the application.
"""
def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
"""Compiles *f* for repeated use on this array context. *f* is expected
to be a `pure function <https://en.wikipedia.org/wiki/Pure_function>`__
performing an array computation.
Control flow statements (``if``, ``while``) that might take different
paths depending on the data lead to undefined behavior and are illegal.
Any data-dependent control flow must be expressed via array functions,
such as ``actx.np.where``.
*f* may be called on placeholder data, to obtain a representation
of the computation performed, or it may be called as part of the actual
computation, on actual data. If *f* is called on placeholder data,
it may be called only once (or a few times).
:arg f: the function executing the computation.
:return: a function with the same signature as *f*.
"""
return f
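# A hedged sketch of the intended use ("rhs" and "state" are illustrative
# names, not part of this interface; "rhs" must be a pure function whose
# control flow does not depend on the array data):
#
#   compiled_rhs = actx.compile(rhs)
#   result = compiled_rhs(state)   # may trace on the first call, then be reused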
# undocumented for now
@property
@abstractmethod
def permits_inplace_modification(self) -> bool:
"""
*True* if the arrays allow in-place modifications.
"""
# undocumented for now
@property
@abstractmethod
def supports_nonscalar_broadcasting(self) -> bool:
"""
*True* if the arrays support non-scalar broadcasting.
"""
# undocumented for now
@property
@abstractmethod
def permits_advanced_indexing(self) -> bool:
"""
*True* if the arrays support :mod:`numpy`'s advanced indexing semantics.
"""
# }}}
# {{{ tagging helpers
def tag_axes(
actx: ArrayContext,
dim_to_tags: Mapping[int, ToTagSetConvertible],
ary: ArrayT) -> ArrayT:
"""
Return a copy of *ary* with the axes in *dim_to_tags* tagged with their
corresponding tags. Equivalent to repeated application of
:meth:`ArrayContext.tag_axis`.
"""
for iaxis, tags in dim_to_tags.items():
ary = actx.tag_axis(iaxis, tags, ary)
return ary
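# A hedged usage sketch (RowTag and ColTag stand in for application-specific
# pytools.tag.Tag subclasses; they are not provided by this module):
#
#   ary = tag_axes(actx, {0: RowTag(), 1: ColTag()}, ary)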
# }}}
class UntransformedCodeWarning(UserWarning):
pass
# vim: foldmethod=marker
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import operator
from abc import ABC, abstractmethod
from typing import Any
import numpy as np
from arraycontext.container import NotAnArrayContainerError, serialize_container
from arraycontext.container.traversal import rec_map_array_container
# {{{ BaseFakeNumpyNamespace
class BaseFakeNumpyNamespace(ABC):
def __init__(self, array_context):
self._array_context = array_context
self.linalg = self._get_fake_numpy_linalg_namespace()
def _get_fake_numpy_linalg_namespace(self):
return BaseFakeNumpyLinalgNamespace(self._array_context)
_numpy_math_functions = frozenset({
# https://numpy.org/doc/stable/reference/routines.math.html
# FIXME: Heads up: not all of these are supported yet.
# But I felt it was important to only dispatch actually existing
# numpy functions to loopy.
# Trigonometric functions
"sin", "cos", "tan", "arcsin", "arccos", "arctan", "hypot", "arctan2",
"degrees", "radians", "unwrap", "deg2rad", "rad2deg",
# Hyperbolic functions
"sinh", "cosh", "tanh", "arcsinh", "arccosh", "arctanh",
# Rounding
"around", "round_", "rint", "fix", "floor", "ceil", "trunc",
# Sums, products, differences
# FIXME: Many of these are reductions or scans.
# "prod", "sum", "nanprod", "nansum", "cumprod", "cumsum", "nancumprod",
# "nancumsum", "diff", "ediff1d", "gradient", "cross", "trapz",
# Exponents and logarithms
"exp", "expm1", "exp2", "log", "log10", "log2", "log1p", "logaddexp",
"logaddexp2",
# Other special functions
"i0", "sinc",
# Floating point routines
"signbit", "copysign", "frexp", "ldexp", "nextafter", "spacing",
# Rational routines
"lcm", "gcd",
# Arithmetic operations
"add", "reciprocal", "positive", "negative", "multiply", "divide", "power",
"subtract", "true_divide", "floor_divide", "float_power", "fmod", "mod",
"modf", "remainder", "divmod",
# Handling complex numbers
"angle", "real", "imag",
# Implemented below:
# "conj", "conjugate",
# Miscellaneous
"convolve", "clip", "sqrt", "cbrt", "square", "absolute", "abs", "fabs",
"sign", "heaviside", "maximum", "fmax", "nan_to_num", "isnan", "minimum",
"fmin",
# FIXME:
# "interp",
})
@abstractmethod
def zeros(self, shape, dtype):
...
@abstractmethod
def zeros_like(self, ary):
...
def conjugate(self, x):
# NOTE: conjugate distributes over object arrays, but it looks for a
# `conjugate` ufunc, while some implementations only have the shorter
# `conj` (e.g. cl.array.Array), so this should work for everybody.
return rec_map_array_container(lambda obj: obj.conj(), x)
conj = conjugate
# {{{ linspace
# based on
# https://github.com/numpy/numpy/blob/v1.25.0/numpy/core/function_base.py#L24-L182
def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis=0):
num = operator.index(num)
if num < 0:
raise ValueError(f"Number of samples, {num}, must be non-negative.")
div = (num - 1) if endpoint else num
# Convert float/complex array scalars to float, gh-3504
# and make sure one can use variables that have an __array_interface__,
# gh-6634
if isinstance(start, self._array_context.array_types):
raise NotImplementedError("start as an actx array")
if isinstance(stop, self._array_context.array_types):
raise NotImplementedError("stop as an actx array")
start = np.array(start) * 1.0
stop = np.array(stop) * 1.0
dt = np.result_type(start, stop, float(num))
if dtype is None:
dtype = dt
integer_dtype = False
else:
integer_dtype = np.issubdtype(dtype, np.integer)
delta = stop - start
y = self.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim)
if div > 0:
step = delta / div
# any_step_zero = _nx.asanyarray(step == 0).any()
any_step_zero = self._array_context.to_numpy(step == 0).any()
if any_step_zero:
delta_actx = self._array_context.from_numpy(delta)
# Special handling for denormal numbers, gh-5437
y = y / div
y = y * delta_actx
else:
step_actx = self._array_context.from_numpy(step)
y = y * step_actx
else:
delta_actx = self._array_context.from_numpy(delta)
# sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0)
# have an undefined step
step = np.nan
# Multiply with delta to allow possible override of output class.
y = y * delta_actx
y += start
# FIXME reenable, without in-place ops
# if endpoint and num > 1:
# y[-1, ...] = stop
if axis != 0:
# y = _nx.moveaxis(y, 0, axis)
raise NotImplementedError("axis != 0")
if integer_dtype:
y = self.floor(y) # pylint: disable=no-member
# FIXME: Use astype
# https://github.com/inducer/pytato/issues/456
if retstep:
return y, step
# return y.astype(dtype), step
else:
return y
# return y.astype(dtype)
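# A minimal usage sketch (assuming a concrete array context "actx" whose
# fake-numpy namespace implements arange; the base class does not):
#
#   x = actx.np.linspace(0.0, 1.0, 11)                    # 11 evenly spaced samples
#   x, dx = actx.np.linspace(0.0, 1.0, 11, retstep=True)  # also return the step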
# }}}
def arange(self, *args: Any, **kwargs: Any):
raise NotImplementedError
# }}}
# {{{ BaseFakeNumpyLinalgNamespace
def _reduce_norm(actx, arys, ord):
from functools import reduce
from numbers import Number
if ord is None:
ord = 2
# NOTE: these are ordered by an expected usage frequency
if ord == 2:
return actx.np.sqrt(sum(subary*subary for subary in arys))
elif ord == np.inf:
return reduce(actx.np.maximum, arys)
elif ord == -np.inf:
return reduce(actx.np.minimum, arys)
elif isinstance(ord, Number) and ord > 0:
return sum(subary**ord for subary in arys)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
class BaseFakeNumpyLinalgNamespace:
def __init__(self, array_context):
self._array_context = array_context
def norm(self, ary, ord=None):
if np.isscalar(ary):
return abs(ary)
actx = self._array_context
try:
from meshmode.dof_array import DOFArray, flat_norm
except ImportError:
pass
else:
if isinstance(ary, DOFArray):
from warnings import warn
warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
"(DOFArrays use 2D arrays internally, and "
"actx.np.linalg.norm should compute matrix norms of those.) "
"This will stop working in 2022. "
"Use meshmode.dof_array.flat_norm instead.",
DeprecationWarning, stacklevel=2)
return flat_norm(ary, ord=ord)
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError:
pass
else:
return _reduce_norm(actx, [
self.norm(subary, ord=ord) for _, subary in iterable
], ord=ord)
if ord is None:
return self.norm(actx.np.ravel(ary, order="A"), 2)
if len(ary.shape) != 1:
raise NotImplementedError("only vector norms are implemented")
if ary.size == 0:
return ary.dtype.type(0)
from numbers import Number
if ord == 2:
return actx.np.sqrt(actx.np.sum(abs(ary)**2))
if ord == np.inf:
return actx.np.max(abs(ary))
elif ord == -np.inf:
return actx.np.min(abs(ary))
elif isinstance(ord, Number) and ord > 0:
return actx.np.sum(abs(ary)**ord)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
# }}}
# vim: foldmethod=marker
from __future__ import annotations
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
"""
.. currentmodule:: arraycontext
.. autoclass:: EagerJAXArrayContext
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from collections.abc import Callable
import numpy as np
from pytools.tag import ToTagSetConvertible
from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
class EagerJAXArrayContext(ArrayContext):
"""
An :class:`ArrayContext` that uses
:class:`jax.Array` instances for its base array
class and performs all array operations eagerly. See
:class:`~arraycontext.PytatoJAXArrayContext` for a lazier version.
.. note::
JAX stores global configuration state in :data:`jax.config`. Callers
are expected to maintain that configuration themselves; the setting most
relevant to scientific computing workloads is ``jax_enable_x64``.
"""
def __init__(self) -> None:
super().__init__()
import jax.numpy as jnp
self.array_types = (jnp.ndarray, )
def _get_fake_numpy_namespace(self):
from .fake_numpy import EagerJAXFakeNumpyNamespace
return EagerJAXFakeNumpyNamespace(self)
def _rec_map_container(
self, func: Callable[[Array], Array], array: ArrayOrContainer,
allowed_types: tuple[type, ...] | None = None, *,
default_scalar: ScalarLike | None = None,
strict: bool = False) -> ArrayOrContainer:
if allowed_types is None:
allowed_types = self.array_types
def _wrapper(ary):
if isinstance(ary, allowed_types):
return func(ary)
elif np.isscalar(ary):
if default_scalar is None:
return ary
else:
return np.array(ary).dtype.type(default_scalar)
else:
raise TypeError(
f"{type(self).__name__}.{func.__name__[1:]} invoked with "
f"an unsupported array type: got '{type(ary).__name__}', "
f"but expected one of {allowed_types}")
return rec_map_array_container(_wrapper, array)
# {{{ ArrayContext interface
def from_numpy(self, array):
def _from_numpy(ary):
import jax
return jax.device_put(ary)
return with_array_context(
self._rec_map_container(_from_numpy, array, allowed_types=(np.ndarray,)),
actx=self)
def to_numpy(self, array):
def _to_numpy(ary):
import jax
return jax.device_get(ary)
return with_array_context(
self._rec_map_container(_to_numpy, array),
actx=None)
def freeze(self, array):
def _freeze(ary):
return ary.block_until_ready()
return with_array_context(self._rec_map_container(_freeze, array), actx=None)
def thaw(self, array):
return with_array_context(array, actx=self)
def tag(self, tags: ToTagSetConvertible, array):
# Sorry, not capable.
return array
def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
# TODO: See `jax.experimental.maps.xmap`, probably that should be useful?
return array
def call_loopy(self, t_unit, **kwargs):
raise NotImplementedError(
"Calling loopy on JAX arrays is not supported. Maybe rewrite"
" the loopy kernel as numpy-flavored array operations using"
" ArrayContext.np.")
def einsum(self, spec, *args, arg_names=None, tagged=()):
import jax.numpy as jnp
if arg_names is not None:
from warnings import warn
warn("'arg_names' don't bear any significance in "
f"{type(self).__name__}.", stacklevel=2)
return jnp.einsum(spec, *args)
def clone(self):
return type(self)()
# }}}
# {{{ properties
@property
def permits_inplace_modification(self):
return False
@property
def supports_nonscalar_broadcasting(self):
return True
@property
def permits_advanced_indexing(self):
return True
# }}}
# vim: foldmethod=marker
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial, reduce
import numpy as np
import jax.numpy as jnp
from arraycontext.container import (
NotAnArrayContainerError,
serialize_container,
)
from arraycontext.container.traversal import (
rec_map_array_container,
rec_map_reduce_array_container,
rec_multimap_array_container,
)
from arraycontext.context import Array, ArrayOrContainer
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace
class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
# Everything is implemented in the base class for now.
pass
class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
"""
A :mod:`numpy` mimic for :class:`~arraycontext.EagerJAXArrayContext`.
"""
def _get_fake_numpy_linalg_namespace(self):
return EagerJAXFakeNumpyLinalgNamespace(self._array_context)
def __getattr__(self, name):
return partial(rec_multimap_array_container, getattr(jnp, name))
# NOTE: the order of these follows the order in numpy docs
# NOTE: when adding a function here, also add it to `array_context.rst` docs!
# {{{ array creation routines
def zeros(self, shape, dtype):
return jnp.zeros(shape=shape, dtype=dtype)
def empty_like(self, ary):
from warnings import warn
warn(f"{type(self._array_context).__name__}.np.empty_like is "
"deprecated and will stop working in 2023. Prefer actx.np.zeros_like "
"instead.",
DeprecationWarning, stacklevel=2)
def _empty_like(array):
return self._array_context.empty(array.shape, array.dtype)
return self._array_context._rec_map_container(_empty_like, ary)
def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.np.zeros(array.shape, array.dtype)
return self._array_context._rec_map_container(
_zeros_like, ary, default_scalar=0)
def ones_like(self, ary):
return self.full_like(ary, 1)
def full_like(self, ary, fill_value):
def _full_like(subary):
return jnp.full_like(subary, fill_value)
return self._array_context._rec_map_container(
_full_like, ary, default_scalar=fill_value)
# }}}
# {{{ array manipulation routines
def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: jnp.reshape(ary, newshape, order=order),
a)
def ravel(self, a, order="C"):
"""
.. warning::
Since :func:`jax.numpy.reshape` does not support orders ``A`` and
``K``, we fall back to ``order="C"`` in those cases.
"""
if order in "AK":
from warnings import warn
warn(f"ravel with order='{order}' not supported by JAX,"
" using order=C.", stacklevel=1)
order = "C"
return rec_map_array_container(
lambda subary: jnp.ravel(subary, order=order), a)
def transpose(self, a, axes=None):
return rec_multimap_array_container(jnp.transpose, a, axes)
def broadcast_to(self, array, shape):
return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
def concatenate(self, arrays, axis=0):
return rec_multimap_array_container(jnp.concatenate, arrays, axis)
def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: jnp.stack(arrays=args, axis=axis),
*arrays)
# }}}
# {{{ linear algebra
def vdot(self, x, y, dtype=None):
from arraycontext import rec_multimap_reduce_array_container
def _rec_vdot(ary1, ary2):
common_dtype = np.result_type(ary1, ary2)
if dtype not in (None, common_dtype):
raise NotImplementedError(
f"{type(self).__name__} cannot take dtype in vdot.")
return jnp.vdot(ary1, ary2)
return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
# }}}
# {{{ logic functions
def all(self, a):
return rec_map_reduce_array_container(
partial(reduce, jnp.logical_and), jnp.all, a)
def any(self, a):
return rec_map_reduce_array_container(
partial(reduce, jnp.logical_or), jnp.any, a)
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
actx = self._array_context
# NOTE: not all backends support `bool` properly, so use `int8` instead
true_ary = actx.from_numpy(np.int8(True))
false_ary = actx.from_numpy(np.int8(False))
def rec_equal(x, y):
if type(x) is not type(y):
return false_ary
try:
serialized_x = serialize_container(x)
serialized_y = serialize_container(y)
except NotAnArrayContainerError:
if x.shape != y.shape:
return false_ary
else:
return jnp.all(jnp.equal(x, y))
else:
if len(serialized_x) != len(serialized_y):
return false_ary
return reduce(
jnp.logical_and,
[(true_ary if kx_i == ky_i else false_ary)
and rec_equal(x_i, y_i)
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y, strict=True)],
true_ary)
return rec_equal(a, b)
# }}}
# {{{ mathematical functions
def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(
sum,
partial(jnp.sum, axis=axis, dtype=dtype),
a)
def amin(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
min = amin
def amax(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
max = amax
# }}}
# {{{ sorting, searching and counting
def where(self, criterion, then, else_):
return rec_multimap_array_container(jnp.where, criterion, then, else_)
# }}}
"""
.. currentmodule:: arraycontext
A :mod:`numpy`-based array context.
.. autoclass:: NumpyArrayContext
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from typing import Any, overload
import numpy as np
import loopy as lp
from pytools.tag import ToTagSetConvertible
from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ContainerOrScalarT,
NumpyOrContainerOrScalar,
UntransformedCodeWarning,
)
class NumpyNonObjectArrayMetaclass(type):
def __instancecheck__(cls, instance: Any) -> bool:
return isinstance(instance, np.ndarray) and instance.dtype != object
class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass):
pass
class NumpyArrayContext(ArrayContext):
"""
An :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays.
.. automethod:: __init__
"""
_loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]
def __init__(self) -> None:
super().__init__()
self._loopy_transform_cache = {}
array_types = (NumpyNonObjectArray,)
def _get_fake_numpy_namespace(self):
from .fake_numpy import NumpyFakeNumpyNamespace
return NumpyFakeNumpyNamespace(self)
# {{{ ArrayContext interface
def clone(self):
return type(self)()
@overload
def from_numpy(self, array: np.ndarray) -> Array:
...
@overload
def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...
def from_numpy(self,
array: NumpyOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
return array
@overload
def to_numpy(self, array: Array) -> np.ndarray:
...
@overload
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...
def to_numpy(self,
array: ArrayOrContainerOrScalar
) -> NumpyOrContainerOrScalar:
return array
def call_loopy(
self,
t_unit: lp.TranslationUnit, **kwargs: Any
) -> dict[str, Array]:
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
try:
executor = self._loopy_transform_cache[t_unit]
except KeyError:
executor = self.transform_loopy_program(t_unit).executor()
self._loopy_transform_cache[t_unit] = executor
_, result = executor(**kwargs)
return result
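# A hedged usage sketch (the kernel below is illustrative; "out" is simply the
# name chosen for the output argument in this example, and the result dict is
# keyed by the kernel's output names):
#
#   import loopy as lp
#   t_unit = lp.make_kernel(
#       "{[i]: 0 <= i < n}",
#       "out[i] = 2 * a[i]")
#   result = actx.call_loopy(t_unit, a=np.arange(10.))["out"]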
def freeze(self, array):
def _freeze(ary):
return ary
return with_array_context(rec_map_array_container(_freeze, array), actx=None)
def thaw(self, array):
def _thaw(ary):
return ary
return with_array_context(rec_map_array_container(_thaw, array), actx=self)
# }}}
def transform_loopy_program(self, t_unit):
from warnings import warn
warn("Using the base "
f"{type(self).__name__}.transform_loopy_program "
"to transform a translation unit. "
"This is a no-op and will result in unoptimized C code for"
"the requested optimization, all in a single statement."
"This will work, but is unlikely to be performant."
f"Instead, subclass {type(self).__name__} and implement "
"the specific transform logic required to transform the program "
"for your package or application. Check higher-level packages "
"(e.g. meshmode), which may already have subclasses you may want "
"to build on.",
UntransformedCodeWarning, stacklevel=2)
return t_unit
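# A hedged sketch of the intended subclassing pattern (the class name and the
# commented-out transformation step are illustrative only, not part of
# arraycontext):
#
#   class MyTransformingNumpyArrayContext(NumpyArrayContext):
#       def transform_loopy_program(self, t_unit):
#           # apply loopy transformations suited to your application here
#           return t_unit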
def tag(self,
tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
# Numpy doesn't support tagging
return array
def tag_axis(self,
iaxis: int, tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
# Numpy doesn't support tagging
return array
def einsum(self, spec, *args, arg_names=None, tagged=()):
return np.einsum(spec, *args)
@property
def permits_inplace_modification(self):
return True
@property
def supports_nonscalar_broadcasting(self):
return True
@property
def permits_advanced_indexing(self):
return True