Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • inducer/arraycontext
  • kaushikcfd/arraycontext
  • fikl2/arraycontext
3 results
Show changes
Commits on Source (556)
Showing
with 3165 additions and 569 deletions
version: 2
updates:
# Set update schedule for GitHub Actions
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
# vim: sw=4
...@@ -7,14 +7,15 @@ on: ...@@ -7,14 +7,15 @@ on:
jobs: jobs:
autopush: autopush:
name: Automatic push to gitlab.tiker.net name: Automatic push to gitlab.tiker.net
if: startsWith(github.repository, 'inducer/')
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- run: | - run: |
mkdir ~/.ssh && echo -e "Host gitlab.tiker.net\n\tStrictHostKeyChecking no\n" >> ~/.ssh/config curl -L -O https://tiker.net/ci-support-v0
eval $(ssh-agent) && echo "$GITLAB_AUTOPUSH_KEY" | ssh-add - . ./ci-support-v0
git fetch --unshallow mirror_github_to_gitlab
git push "git@gitlab.tiker.net:inducer/$(basename $GITHUB_REPOSITORY).git" main
env: env:
GITLAB_AUTOPUSH_KEY: ${{ secrets.GITLAB_AUTOPUSH_KEY }} GITLAB_AUTOPUSH_KEY: ${{ secrets.GITLAB_AUTOPUSH_KEY }}
......
...@@ -7,27 +7,35 @@ on: ...@@ -7,27 +7,35 @@ on:
schedule: schedule:
- cron: '17 3 * * 0' - cron: '17 3 * * 0'
concurrency:
group: ${{ github.head_ref || github.ref_name }}
cancel-in-progress: true
jobs: jobs:
flake8: typos:
name: Flake8 name: Typos
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- uses: crate-ci/typos@master
ruff:
name: Ruff
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- -
uses: actions/setup-python@v1 uses: actions/setup-python@v5
with:
# matches compat target in setup.py
python-version: '3.6'
- name: "Main Script" - name: "Main Script"
run: | run: |
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh pip install ruff
. ./prepare-and-run-flake8.sh "$(basename $GITHUB_REPOSITORY)" test examples ruff check
pylint: pylint:
name: Pylint name: Pylint
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- name: "Main Script" - name: "Main Script"
run: | run: |
USE_CONDA_BUILD=1 USE_CONDA_BUILD=1
...@@ -38,29 +46,46 @@ jobs: ...@@ -38,29 +46,46 @@ jobs:
name: Mypy name: Mypy
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
-
uses: actions/setup-python@v1
with:
python-version: '3.x'
- name: "Main Script" - name: "Main Script"
run: | run: |
curl -L -O https://tiker.net/ci-support-v0 curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0 . ./ci-support-v0
build_py_project_in_conda_env build_py_project_in_conda_env
python -m pip install mypy python -m pip install mypy pytest
python -m mypy "$(basename $GITHUB_REPOSITORY)" test ./run-mypy.sh
pytest3: pytest3_pocl:
name: Pytest Conda Py3 name: Pytest Conda Py3 POCL
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- name: "Main Script" - name: "Main Script"
run: | run: |
export MPLBACKEND=Agg curl -L -O https://tiker.net/ci-support-v0
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/ci-support.sh . ./ci-support-v0
. ./ci-support.sh build_py_project_in_conda_env
test_py_project
pytest3_intel_cl:
name: Pytest Conda Py3 Intel
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: "Main Script"
run: |
curl -L -O https://raw.githubusercontent.com/illinois-scicomp/machine-shop-maintenance/main/install-intel-icd.sh
sudo bash ./install-intel-icd.sh
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
echo "- ocl-icd-system" >> "$CONDA_ENVIRONMENT"
sed -i "/pocl/ d" "$CONDA_ENVIRONMENT"
export PYOPENCL_TEST=intel
source /opt/enable-intel-cl.sh
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_conda_env build_py_project_in_conda_env
test_py_project test_py_project
...@@ -68,7 +93,7 @@ jobs: ...@@ -68,7 +93,7 @@ jobs:
name: Examples Conda Py3 name: Examples Conda Py3
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- name: "Main Script" - name: "Main Script"
run: | run: |
export MPLBACKEND=Agg export MPLBACKEND=Agg
...@@ -80,15 +105,15 @@ jobs: ...@@ -80,15 +105,15 @@ jobs:
name: Documentation name: Documentation
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- -
uses: actions/setup-python@v1 uses: actions/setup-python@v5
with: with:
python-version: '3.x' python-version: '3.x'
- name: "Main Script" - name: "Main Script"
run: | run: |
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/ci-support.sh curl -L -O https://tiker.net/ci-support-v0
. ci-support.sh . ci-support-v0
build_py_project_in_conda_env build_py_project_in_conda_env
conda install graphviz conda install graphviz
...@@ -98,43 +123,23 @@ jobs: ...@@ -98,43 +123,23 @@ jobs:
downstream_tests: downstream_tests:
strategy: strategy:
matrix: matrix:
#downstream_project: [meshmode, grudge, pytential, mirgecom] downstream_project: [meshmode, grudge, mirgecom, mirgecom_examples]
downstream_project: [meshmode, grudge, mirgecom] fail-fast: false
name: Tests for downstream project ${{ matrix.downstream_project }} name: Tests for downstream project ${{ matrix.downstream_project }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- name: "Main Script" - name: "Main Script"
env: env:
DOWNSTREAM_PROJECT: ${{ matrix.downstream_project }} DOWNSTREAM_PROJECT: ${{ matrix.downstream_project }}
run: | run: |
if test "$DOWNSTREAM_PROJECT" = "mirgecom"; then
git clone "https://github.com/illinois-ceesd/$DOWNSTREAM_PROJECT.git"
else
git clone "https://github.com/inducer/$DOWNSTREAM_PROJECT.git"
fi
cd "$DOWNSTREAM_PROJECT"
echo "*** $DOWNSTREAM_PROJECT version: $(git rev-parse --short HEAD)"
sed -i "/egg=arraycontext/ c git+file://$(readlink -f ..)#egg=arraycontext" requirements.txt
# Avoid slow or complicated tests in downstream projects
export PYTEST_ADDOPTS="-k 'not (slowtest or octave or mpi)'"
if test "$DOWNSTREAM_PROJECT" = "mirgecom"; then
# can't turn off MPI in mirgecom
sudo apt-get update
sudo apt-get install openmpi-bin libopenmpi-dev
export CONDA_ENVIRONMENT=conda-env.yml
export CISUPPORT_PARALLEL_PYTEST=no
else
sed -i "/mpi4py/ d" requirements.txt
fi
curl -L -O https://tiker.net/ci-support-v0 curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0 . ./ci-support-v0
build_py_project_in_conda_env test_downstream "$DOWNSTREAM_PROJECT"
test_py_project
if [[ "$DOWNSTREAM_PROJECT" = "meshmode" ]]; then
python ../examples/simple-dg.py --lazy
fi
# vim: sw=4 # vim: sw=4
...@@ -21,3 +21,6 @@ a.out ...@@ -21,3 +21,6 @@ a.out
.pytest_cache .pytest_cache
test/nodal-dg test/nodal-dg
.pylintrc.yml
.run-pylint.py
Python 3 POCL: Python 3 POCL:
script: | script: |
export PY_EXE=python3 export PYOPENCL_TEST=portable:cpu
export PYOPENCL_TEST=portable:pthread export EXTRA_INSTALL="jax[cpu]"
# cython is here because pytential (for now, for TS) depends on it export JAX_PLATFORMS=cpu
export EXTRA_INSTALL="pybind11 cython numpy mako mpi4py oct2py"
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh
. ./build-and-test-py-project.sh . ./build-and-test-py-project.sh
tags: tags:
...@@ -18,12 +17,30 @@ Python 3 POCL: ...@@ -18,12 +17,30 @@ Python 3 POCL:
Python 3 Nvidia Titan V: Python 3 Nvidia Titan V:
script: | script: |
export PY_EXE=python3 curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
export PYOPENCL_TEST=nvi:titan export PYOPENCL_TEST=nvi:titan
export EXTRA_INSTALL="pybind11 cython numpy mako oct2py" build_py_project_in_venv
# cython is here because pytential (for now, for TS) depends on it pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh test_py_project
. ./build-and-test-py-project.sh
tags:
- python3
- nvidia-titan-v
except:
- tags
artifacts:
reports:
junit: test/pytest.xml
Python 3 POCL Nvidia Titan V:
script: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
export PYOPENCL_TEST=port:titan
build_py_project_in_venv
test_py_project
tags: tags:
- python3 - python3
- nvidia-titan-v - nvidia-titan-v
...@@ -36,10 +53,7 @@ Python 3 Nvidia Titan V: ...@@ -36,10 +53,7 @@ Python 3 Nvidia Titan V:
Python 3 POCL Examples: Python 3 POCL Examples:
script: script:
- test -n "$SKIP_EXAMPLES" && exit - test -n "$SKIP_EXAMPLES" && exit
- export PY_EXE=python3 - export PYOPENCL_TEST=portable:cpu
- export PYOPENCL_TEST=portable:pthread
# cython is here because pytential (for now, for TS) depends on it
- export EXTRA_INSTALL="pybind11 cython numpy mako matplotlib"
- curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-py-project-and-run-examples.sh - curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-py-project-and-run-examples.sh
- ". ./build-py-project-and-run-examples.sh" - ". ./build-py-project-and-run-examples.sh"
tags: tags:
...@@ -51,7 +65,11 @@ Python 3 POCL Examples: ...@@ -51,7 +65,11 @@ Python 3 POCL Examples:
Python 3 Conda: Python 3 Conda:
script: | script: |
export MPLBACKEND=Agg export PYOPENCL_TEST=portable:cpu
# Avoid crashes like https://gitlab.tiker.net/inducer/arraycontext/-/jobs/536021
sed -i 's/jax/jax !=0.4.6/' .test-conda-env-py3.yml
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh
. ./build-and-test-py-project-within-miniconda.sh . ./build-and-test-py-project-within-miniconda.sh
tags: tags:
...@@ -63,26 +81,24 @@ Python 3 Conda: ...@@ -63,26 +81,24 @@ Python 3 Conda:
Documentation: Documentation:
script: | script: |
EXTRA_INSTALL="pybind11 cython numpy"
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-docs.sh curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-docs.sh
CI_SUPPORT_SPHINX_VERSION_SPECIFIER=">=4.0" CI_SUPPORT_SPHINX_VERSION_SPECIFIER=">=4.0"
. ./build-docs.sh . ./build-docs.sh
tags: tags:
- python3 - python3
Flake8: Ruff:
script: script:
- curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh - pipx install ruff
- . ./prepare-and-run-flake8.sh "$CI_PROJECT_NAME" test examples - ruff check
tags: tags:
- python3 - docker-runner
except: except:
- tags - tags
Pylint: Pylint:
script: | script: |
export PY_EXE=python3 EXTRA_INSTALL="jax[cpu]"
EXTRA_INSTALL="Cython pybind11 numpy mako matplotlib scipy mpi4py oct2py"
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh
. ./prepare-and-run-pylint.sh "$CI_PROJECT_NAME" examples/*.py test/test_*.py . ./prepare-and-run-pylint.sh "$CI_PROJECT_NAME" examples/*.py test/test_*.py
tags: tags:
...@@ -92,12 +108,30 @@ Pylint: ...@@ -92,12 +108,30 @@ Pylint:
Mypy: Mypy:
script: | script: |
EXTRA_INSTALL="mypy pytest"
curl -L -O https://tiker.net/ci-support-v0 curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0 . ./ci-support-v0
build_py_project_in_venv build_py_project_in_venv
python -m pip install mypy ./run-mypy.sh
python -m mypy "$CI_PROJECT_NAME" test
tags: tags:
- python3 - python3
except: except:
- tags - tags
Downstream:
parallel:
matrix:
- DOWNSTREAM_PROJECT: [meshmode, grudge, mirgecom, mirgecom_examples]
tags:
- large-node
- "docker-runner"
script: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
test_downstream "$DOWNSTREAM_PROJECT"
if [[ "$DOWNSTREAM_PROJECT" = "meshmode" ]]; then
python ../examples/simple-dg.py --lazy
fi
- arg: ignore
val:
- firedrake
- to_firedrake.py
- from_firedrake.py
- test_firedrake_interop.py
- arg: extension-pkg-whitelist
val: mayavi
...@@ -8,8 +8,10 @@ dependencies: ...@@ -8,8 +8,10 @@ dependencies:
- git - git
- libhwloc=2 - libhwloc=2
- numpy - numpy
- pocl # pocl 3.1 required for full SVM functionality
- pocl>=3.1
- mako - mako
- pyopencl - pyopencl
- islpy - islpy
- pip - pip
- jax
include doc/*.rst
include doc/conf.py
include doc/make.bat
include doc/Makefile
include examples/*.py
...@@ -7,7 +7,7 @@ arraycontext: Choose your favorite ``numpy``-workalike ...@@ -7,7 +7,7 @@ arraycontext: Choose your favorite ``numpy``-workalike
.. image:: https://github.com/inducer/arraycontext/workflows/CI/badge.svg .. image:: https://github.com/inducer/arraycontext/workflows/CI/badge.svg
:alt: Github Build Status :alt: Github Build Status
:target: https://github.com/inducer/arraycontext/actions?query=branch%3Amain+workflow%3ACI :target: https://github.com/inducer/arraycontext/actions?query=branch%3Amain+workflow%3ACI
.. image:: https://badge.fury.io/py/arraycontext.png .. image:: https://badge.fury.io/py/arraycontext.svg
:alt: Python Package Index Release Page :alt: Python Package Index Release Page
:target: https://pypi.org/project/arraycontext/ :target: https://pypi.org/project/arraycontext/
...@@ -17,7 +17,9 @@ implementations for: ...@@ -17,7 +17,9 @@ implementations for:
- numpy - numpy
- `PyOpenCL <https://documen.tician.de/pyopencl/array.html>`__ - `PyOpenCL <https://documen.tician.de/pyopencl/array.html>`__
- `JAX <https://jax.readthedocs.io/en/latest/>`__
- `Pytato <https://documen.tician.de/pytato>`__ (for lazy/deferred evaluation) - `Pytato <https://documen.tician.de/pytato>`__ (for lazy/deferred evaluation)
with backends for ``pyopencl`` and ``jax``.
- Debugging - Debugging
- Profiling - Profiling
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
An array context is an abstraction that helps you dispatch between multiple An array context is an abstraction that helps you dispatch between multiple
implementations of :mod:`numpy`-like :math:`n`-dimensional arrays. implementations of :mod:`numpy`-like :math:`n`-dimensional arrays.
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
...@@ -28,65 +29,147 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -28,65 +29,147 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from .context import ArrayContext
from .metadata import CommonSubexpressionTag, FirstAxisIsElementsTag
from .container import ( from .container import (
ArrayContainer, ArithArrayContainer,
is_array_container, is_array_container_type, ArrayContainer,
get_container_context, get_container_context_recursively, ArrayContainerT,
serialize_container, deserialize_container) NotAnArrayContainerError,
from .container.arithmetic import with_container_arithmetic SerializationKey,
SerializedContainer,
deserialize_container,
get_container_context_opt,
get_container_context_recursively,
get_container_context_recursively_opt,
is_array_container,
is_array_container_type,
register_multivector_as_array_container,
serialize_container,
)
from .container.arithmetic import (
BcastUntilActxArray,
with_container_arithmetic,
)
from .container.dataclass import dataclass_array_container from .container.dataclass import dataclass_array_container
from .container.traversal import ( from .container.traversal import (
map_array_container, flat_size_and_dtype,
multimap_array_container, flatten,
rec_map_array_container, freeze,
rec_multimap_array_container, from_numpy,
mapped_over_array_containers, map_array_container,
multimapped_over_array_containers, map_reduce_array_container,
thaw, freeze, mapped_over_array_containers,
from_numpy, to_numpy) multimap_array_container,
multimap_reduce_array_container,
multimapped_over_array_containers,
outer,
rec_map_array_container,
rec_map_reduce_array_container,
rec_multimap_array_container,
rec_multimap_reduce_array_container,
stringify_array_container_tree,
thaw,
to_numpy,
unflatten,
with_array_context,
)
from .context import (
Array,
ArrayContext,
ArrayOrArithContainer,
ArrayOrArithContainerOrScalar,
ArrayOrArithContainerOrScalarT,
ArrayOrArithContainerT,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ArrayOrContainerT,
ArrayT,
Scalar,
ScalarLike,
tag_axes,
)
from .impl.jax import EagerJAXArrayContext
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .pytest import pytest_generate_tests_for_array_contexts
from .loopy import make_loopy_program from .loopy import make_loopy_program
from .pytest import (
PytestArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
pytest_generate_tests_for_array_contexts,
)
from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
__all__ = ( __all__ = (
"ArrayContext", "ArithArrayContainer",
"Array",
"CommonSubexpressionTag", "ArrayContainer",
"FirstAxisIsElementsTag", "ArrayContainerT",
"ArrayContext",
"ArrayContainer", "ArrayOrArithContainer",
"is_array_container", "is_array_container_type", "ArrayOrArithContainerOrScalar",
"get_container_context", "get_container_context_recursively", "ArrayOrArithContainerOrScalarT",
"serialize_container", "deserialize_container", "ArrayOrArithContainerT",
"with_container_arithmetic", "ArrayOrContainer",
"dataclass_array_container", "ArrayOrContainerOrScalar",
"ArrayOrContainerOrScalarT",
"map_array_container", "multimap_array_container", "ArrayOrContainerT",
"rec_map_array_container", "rec_multimap_array_container", "ArrayT",
"mapped_over_array_containers", "BcastUntilActxArray",
"multimapped_over_array_containers", "CommonSubexpressionTag",
"thaw", "freeze", "EagerJAXArrayContext",
"from_numpy", "to_numpy", "ElementwiseMapKernelTag",
"NotAnArrayContainerError",
"PyOpenCLArrayContext", "PytatoArrayContext", "NumpyArrayContext",
"PyOpenCLArrayContext",
"make_loopy_program", "PytatoJAXArrayContext",
"PytatoPyOpenCLArrayContext",
"pytest_generate_tests_for_array_contexts" "PytestArrayContextFactory",
) "PytestPyOpenCLArrayContextFactory",
"Scalar",
"ScalarLike",
def _acf(): "SerializationKey",
"SerializedContainer",
"dataclass_array_container",
"deserialize_container",
"flat_size_and_dtype",
"flatten",
"freeze",
"from_numpy",
"get_container_context_opt",
"get_container_context_recursively",
"get_container_context_recursively_opt",
"is_array_container",
"is_array_container_type",
"make_loopy_program",
"map_array_container",
"map_reduce_array_container",
"mapped_over_array_containers",
"multimap_array_container",
"multimap_reduce_array_container",
"multimapped_over_array_containers",
"outer",
"pytest_generate_tests_for_array_contexts",
"rec_map_array_container",
"rec_map_reduce_array_container",
"rec_multimap_array_container",
"rec_multimap_reduce_array_container",
"register_multivector_as_array_container",
"serialize_container",
"stringify_array_container_tree",
"tag_axes",
"thaw",
"to_numpy",
"unflatten",
"with_array_context",
"with_container_arithmetic",
)
# {{{ deprecation handling
def _deprecated_acf():
"""A tiny undocumented function to pass to tests that take an ``actx_factory`` """A tiny undocumented function to pass to tests that take an ``actx_factory``
argument when running them from the command line. argument when running them from the command line.
""" """
...@@ -96,4 +179,27 @@ def _acf(): ...@@ -96,4 +179,27 @@ def _acf():
queue = cl.CommandQueue(context) queue = cl.CommandQueue(context)
return PyOpenCLArrayContext(queue) return PyOpenCLArrayContext(queue)
_depr_name_to_replacement_and_obj = {
"get_container_context": (
"get_container_context_opt",
get_container_context_opt, 2022),
}
def __getattr__(name):
replacement_and_obj = _depr_name_to_replacement_and_obj.get(name)
if replacement_and_obj is not None:
replacement, obj, year = replacement_and_obj
from warnings import warn
warn(f"'arraycontext.{name}' is deprecated. "
f"Use '{replacement}' instead. "
f"'arraycontext.{name}' will continue to work until {year}.",
DeprecationWarning, stacklevel=2)
return obj
else:
raise AttributeError(name)
# }}}
# vim: foldmethod=marker # vim: foldmethod=marker
...@@ -3,25 +3,58 @@ ...@@ -3,25 +3,58 @@
""" """
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
.. autoclass:: ArrayContainer
.. autoclass:: ArithArrayContainer
.. class:: ArrayContainerT .. class:: ArrayContainerT
:canonical: arraycontext.container.ArrayContainerT
:class:`~typing.TypeVar` for array container-like objects. A type variable with a lower bound of :class:`ArrayContainer`.
.. autoclass:: ArrayContainer .. autoexception:: NotAnArrayContainerError
Serialization/deserialization Serialization/deserialization
----------------------------- -----------------------------
.. autofunction:: is_array_container
.. autoclass:: SerializationKey
.. autoclass:: SerializedContainer
.. autofunction:: is_array_container_type
.. autofunction:: serialize_container .. autofunction:: serialize_container
.. autofunction:: deserialize_container .. autofunction:: deserialize_container
Context retrieval Context retrieval
----------------- -----------------
.. autofunction:: get_container_context .. autofunction:: get_container_context_opt
.. autofunction:: get_container_context_recursively .. autofunction:: get_container_context_recursively
.. autofunction:: get_container_context_recursively_opt
:class:`~pymbolic.geometric_algebra.MultiVector` support
---------------------------------------------------------
.. autofunction:: register_multivector_as_array_container
.. currentmodule:: arraycontext.container
Canonical locations for type annotations
----------------------------------------
.. class:: ArrayContainerT
:canonical: arraycontext.ArrayContainerT
.. class:: ArrayOrContainerT
:canonical: arraycontext.ArrayOrContainerT
.. class:: SerializationKey
:canonical: arraycontext.SerializationKey
.. class:: SerializedContainer
:canonical: arraycontext.SerializedContainer
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
...@@ -47,19 +80,30 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -47,19 +80,30 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from collections.abc import Hashable, Sequence
from functools import singledispatch from functools import singledispatch
from arraycontext.context import ArrayContext from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar
from typing import Any, Iterable, Tuple, TypeVar, Optional
# For use in singledispatch type annotations, because sphinx can't figure out
# what 'np' is.
import numpy
import numpy as np import numpy as np
from typing_extensions import Self
from arraycontext.context import ArrayContext, ArrayOrScalar
ArrayContainerT = TypeVar("ArrayContainerT") if TYPE_CHECKING:
from pymbolic.geometric_algebra import MultiVector
from arraycontext import ArrayOrContainer
# {{{ ArrayContainer # {{{ ArrayContainer
class ArrayContainer: class ArrayContainer(Protocol):
r""" """
A generic container for the array type supported by the A protocol for generic containers of the array type supported by the
:class:`ArrayContext`. :class:`ArrayContext`.
The functionality required for the container to operated is supplied via The functionality required for the container to operated is supplied via
...@@ -70,13 +114,11 @@ class ArrayContainer: ...@@ -70,13 +114,11 @@ class ArrayContainer:
of the array. of the array.
* :func:`deserialize_container` for deserialization, which constructs a * :func:`deserialize_container` for deserialization, which constructs a
container from a set of components. container from a set of components.
* :func:`get_container_context` retrieves the :class:`ArrayContext` from * :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
a container, if it has one. a container, if it has one.
This allows enumeration of the component arrays in a container and the This allows enumeration of the component arrays in a container and the
construction of modified containers from an iterable of those component arrays. construction of modified containers from an iterable of those component arrays.
:func:`is_array_container` will return *True* for types that have
a container serialization function registered.
Packages may register their own types as array containers. They must not Packages may register their own types as array containers. They must not
register other types (e.g. :class:`list`) as array containers. register other types (e.g. :class:`list`) as array containers.
...@@ -92,67 +134,140 @@ class ArrayContainer: ...@@ -92,67 +134,140 @@ class ArrayContainer:
.. note:: .. note::
This class is used in type annotation. Inheriting from it confers no This class is used in type annotation and as a marker of array container
special meaning or behavior. attributes for :func:`~arraycontext.dataclass_array_container`.
As a protocol, it is not intended as a superclass.
""" """
# Array containers do not need to have any particular features, so this
# protocol is deliberately empty.
# This *is* used as a type annotation in dataclasses that are processed
# by dataclass_array_container, where it's used to recognize attributes
# that are container-typed.
class ArithArrayContainer(ArrayContainer, Protocol):
"""
A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
"""
# This is loose and permissive, assuming that any array can be added
# to any container. The alternative would be to plaster type-ignores
# on all those uses. Achieving typing precision on what broadcasting is
# allowable seems like a huge endeavor and is likely not feasible without
# a mypy plugin. Maybe some day? -AK, November 2024
def __neg__(self) -> Self: ...
def __abs__(self) -> Self: ...
def __add__(self, other: ArrayOrScalar | Self) -> Self: ...
def __radd__(self, other: ArrayOrScalar | Self) -> Self: ...
def __sub__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ...
def __mul__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ...
def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ...
def __pow__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...
ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)
class NotAnArrayContainerError(TypeError):
""":class:`TypeError` subclass raised when an array container is expected."""
SerializationKey: TypeAlias = Hashable
SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]]
@singledispatch @singledispatch
def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]: def serialize_container(
r"""Serialize the array container into an iterable over its components. ary: ArrayContainer) -> SerializedContainer:
r"""Serialize the array container into a sequence over its components.
The order of the components and their identifiers are entirely under The order of the components and their identifiers are entirely under
the control of the container class. the control of the container class. However, the order is required to be
deterministic, i.e. two calls to :func:`serialize_container` on
array containers of the same types with the same number of
sub-arrays must result in a sequence with the keys in the same
order.
If *ary* is mutable, the serialization function is not required to ensure If *ary* is mutable, the serialization function is not required to ensure
that the serialization result reflects the array state at the time of the that the serialization result reflects the array state at the time of the
call to :func:`serialize_container`. call to :func:`serialize_container`.
:returns: an :class:`Iterable` of 2-tuples where the first :returns: a :class:`Sequence` of 2-tuples where the first
entry is an identifier for the component and the second entry entry is an identifier for the component and the second entry
is an array-like component of the :class:`ArrayContainer`. is an array-like component of the :class:`ArrayContainer`.
Components can themselves be :class:`ArrayContainer`\ s, allowing Components can themselves be :class:`ArrayContainer`\ s, allowing
for arbitrarily nested structures. The identifiers need to be hashable for arbitrarily nested structures. The identifiers need to be hashable
but are otherwise treated as opaque. but are otherwise treated as opaque.
""" """
raise NotImplementedError(type(ary).__name__) raise NotAnArrayContainerError(
f"'{type(ary).__name__}' cannot be serialized as a container")
@singledispatch @singledispatch
def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) -> Any: def deserialize_container(
"""Deserialize an iterable into an array container. template: ArrayContainerT,
serialized: SerializedContainer) -> ArrayContainerT:
"""Deserialize a sequence into an array container following a *template*.
:param template: an instance of an existing object that :param template: an instance of an existing object that
can be used to aid in the deserialization. For a similar choice can be used to aid in the deserialization. For a similar choice
see :attr:`~numpy.class.__array_finalize__`. see :attr:`~numpy.class.__array_finalize__`.
:param iterable: an iterable that mirrors the output of :param serialized: a sequence that mirrors the output of
:meth:`serialize_container`. :meth:`serialize_container`.
""" """
raise NotImplementedError(type(template).__name__) raise NotAnArrayContainerError(
f"'{type(template).__name__}' cannot be deserialized as a container")
def is_array_container_type(cls: type) -> bool: def is_array_container_type(cls: type) -> bool:
""" """
:returns: *True* if the type *cls* has a registered implementation of :returns: *True* if the type *cls* has a registered implementation of
:func:`serialize_container`, or if it is an :class:`ArrayContainer`. :func:`serialize_container`, or if it is an :class:`ArrayContainer`.
.. warning::
Not all instances of a type that this function labels an array container
must automatically be array containers. For example, while this
function will say that :class:`numpy.ndarray` is an array container
type, only object arrays *actually are* array containers.
""" """
assert isinstance(cls, type), f"must pass a {type!r}, not a '{cls!r}'"
return ( return (
cls is ArrayContainer cls is ArrayContainer
or (serialize_container.dispatch(cls) or (serialize_container.dispatch(cls)
is not serialize_container.__wrapped__)) # type: ignore is not serialize_container.__wrapped__)) # type:ignore[attr-defined]
def is_array_container(ary: Any) -> bool: def is_array_container(ary: object) -> bool:
""" """
:returns: *True* if the instance *ary* has a registered implementation of :returns: *True* if the instance *ary* has a registered implementation of
:func:`serialize_container`. :func:`serialize_container`.
""" """
from warnings import warn
warn("is_array_container is deprecated and will be removed in 2022. "
"If you must know precisely whether something is an array container, "
"try serializing it and catch NotAnArrayContainerError. For a "
"cheaper option, see is_array_container_type.",
DeprecationWarning, stacklevel=2)
return (serialize_container.dispatch(ary.__class__) return (serialize_container.dispatch(ary.__class__)
is not serialize_container.__wrapped__) # type: ignore is not serialize_container.__wrapped__ # type:ignore[attr-defined]
# numpy values with scalar elements aren't array containers
and not (isinstance(ary, np.ndarray)
and ary.dtype.kind != "O")
)
@singledispatch @singledispatch
def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]: def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None:
"""Retrieves the :class:`ArrayContext` from the container, if any. """Retrieves the :class:`ArrayContext` from the container, if any.
This function is not recursive, so it will only search at the root level This function is not recursive, so it will only search at the root level
...@@ -167,25 +282,37 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]: ...@@ -167,25 +282,37 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]:
# {{{ object arrays as array containers # {{{ object arrays as array containers
@serialize_container.register(np.ndarray) @serialize_container.register(np.ndarray)
def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]: def _serialize_ndarray_container(
ary: numpy.ndarray) -> SerializedContainer:
if ary.dtype.char != "O": if ary.dtype.char != "O":
raise ValueError( raise NotAnArrayContainerError(
f"only object arrays are supported, given dtype '{ary.dtype}'") f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'")
return np.ndenumerate(ary) # special-cased for speed
if ary.ndim == 1:
return [(i, ary[i]) for i in range(ary.shape[0])]
elif ary.ndim == 2:
return [((i, j), ary[i, j])
for i in range(ary.shape[0])
for j in range(ary.shape[1])
]
else:
return list(np.ndenumerate(ary))
@deserialize_container.register(np.ndarray) @deserialize_container.register(np.ndarray)
def _deserialize_ndarray_container( # https://github.com/python/mypy/issues/13040
template: np.ndarray, def _deserialize_ndarray_container( # type: ignore[misc]
iterable: Iterable[Tuple[Any, Any]]) -> np.ndarray: template: numpy.ndarray,
serialized: SerializedContainer) -> numpy.ndarray:
# disallow subclasses # disallow subclasses
assert type(template) is np.ndarray assert type(template) is np.ndarray
assert template.dtype.char == "O" assert template.dtype.char == "O"
result = type(template)(template.shape, dtype=object) result = type(template)(template.shape, dtype=object)
for i, subary in iterable: for i, subary in serialized:
result[i] = subary # FIXME: numpy annotations don't seem to handle object arrays very well
result[i] = subary # type: ignore[call-overload]
return result return result
...@@ -194,37 +321,101 @@ def _deserialize_ndarray_container( ...@@ -194,37 +321,101 @@ def _deserialize_ndarray_container(
# {{{ get_container_context_recursively # {{{ get_container_context_recursively
def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]: def get_container_context_recursively_opt(
ary: ArrayContainer) -> ArrayContext | None:
"""Walks the :class:`ArrayContainer` hierarchy to find an """Walks the :class:`ArrayContainer` hierarchy to find an
:class:`ArrayContext` associated with it. :class:`ArrayContext` associated with it.
If different components that have different array contexts are found at If different components that have different array contexts are found at
any level, an assertion error is raised. any level, an assertion error is raised.
"""
actx = None
if not is_array_container(ary):
return actx
Returns *None* if no array context was found.
"""
# try getting the array context directly # try getting the array context directly
actx = get_container_context(ary) actx = get_container_context_opt(ary)
if actx is not None: if actx is not None:
return actx return actx
for _, subary in serialize_container(ary): try:
context = get_container_context_recursively(subary) iterable = serialize_container(ary)
if context is None: except NotAnArrayContainerError:
continue return actx
else:
for _, subary in iterable:
context = get_container_context_recursively_opt(subary)
if context is None:
continue
if not __debug__:
return context
elif actx is None:
actx = context
else:
assert actx is context
return actx
if not __debug__: def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None:
return context """Walks the :class:`ArrayContainer` hierarchy to find an
elif actx is None: :class:`ArrayContext` associated with it.
actx = context
else: If different components that have different array contexts are found at
assert actx is context any level, an assertion error is raised.
Raises an error if no array container is found.
"""
actx = get_container_context_recursively_opt(ary)
if actx is None:
# raise ValueError("no array context was found")
from warnings import warn
warn("No array context was found. This will be an error starting in "
"July of 2022. If you would like the function to return "
"None if no array context was found, use "
"get_container_context_recursively_opt.",
DeprecationWarning, stacklevel=2)
return actx return actx
# }}} # }}}
# {{{ MultiVector support, see pymbolic.geometric_algebra
# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic.
# (Though clearly there exists a dependency via loopy.)
def _serialize_multivec_as_container(mv: MultiVector) -> SerializedContainer:
return list(mv.data.items())
# FIXME: Ignored due to https://github.com/python/mypy/issues/13040
def _deserialize_multivec_as_container( # type: ignore[misc]
template: MultiVector,
serialized: SerializedContainer) -> MultiVector:
from pymbolic.geometric_algebra import MultiVector
return MultiVector(dict(serialized), space=template.space)
def _get_container_context_opt_from_multivec(mv: MultiVector) -> None:
return None
def register_multivector_as_array_container() -> None:
"""Registers :class:`~pymbolic.geometric_algebra.MultiVector` as an
:class:`ArrayContainer`. This function may be called multiple times. The
second and subsequent calls have no effect.
"""
from pymbolic.geometric_algebra import MultiVector
if MultiVector not in serialize_container.registry:
serialize_container.register(MultiVector)(_serialize_multivec_as_container)
deserialize_container.register(MultiVector)(
_deserialize_multivec_as_container)
get_container_context_opt.register(MultiVector)(
_get_container_context_opt_from_multivec)
assert MultiVector in serialize_container.registry
# }}}
# vim: foldmethod=marker # vim: foldmethod=marker
This diff is collapsed.
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
.. autofunction:: dataclass_array_container .. autofunction:: dataclass_array_container
""" """
from __future__ import annotations
__copyright__ = """ __copyright__ = """
...@@ -30,34 +31,170 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -30,34 +31,170 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from dataclasses import fields from collections.abc import Mapping, Sequence
from dataclasses import fields, is_dataclass
from typing import NamedTuple, Union, get_args, get_origin
from arraycontext.container import is_array_container_type from arraycontext.container import is_array_container_type
# {{{ dataclass containers # {{{ dataclass containers
class _Field(NamedTuple):
"""Small lookalike for :class:`dataclasses.Field`."""
init: bool
name: str
type: type
def is_array_type(tp: type) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)
def dataclass_array_container(cls: type) -> type: def dataclass_array_container(cls: type) -> type:
"""A class decorator that makes the class to which it is applied a """A class decorator that makes the class to which it is applied an
:class:`ArrayContainer` by registering appropriate implementations of :class:`ArrayContainer` by registering appropriate implementations of
:func:`serialize_container` and :func:`deserialize_container`. :func:`serialize_container` and :func:`deserialize_container`.
*cls* must be a :func:`~dataclasses.dataclass`. *cls* must be a :func:`~dataclasses.dataclass`.
Attributes that are not array containers are allowed. In order to decide Attributes that are not array containers are allowed. In order to decide
whether an attribute is an array container, the declared attribute type whether an attribute is an array container, the declared attribute type
is checked by the criteria from :func:`is_array_container`. is checked by the criteria from :func:`is_array_container_type`. This
includes some support for type annotations:
* a :class:`typing.Union` of array containers is considered an array container.
* other type annotations, e.g. :class:`typing.Optional`, are not considered
array containers, even if they wrap one.
.. note::
When type annotations are strings (e.g. because of
``from __future__ import annotations``),
this function relies on :func:`inspect.get_annotations`
(with ``eval_str=True``) to obtain type annotations. This
means that *cls* must live in a module that is importable.
""" """
from dataclasses import is_dataclass
from types import GenericAlias, UnionType
assert is_dataclass(cls) assert is_dataclass(cls)
array_fields = [ def is_array_field(f: _Field) -> bool:
f for f in fields(cls) if is_array_container_type(f.type)] field_type = f.type
non_array_fields = [
f for f in fields(cls) if not is_array_container_type(f.type)] # NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
#
# `Optional[ArrayContainer]` is not allowed, since `None` is not
# handled by `with_container_arithmetic`, which is the common case
# for current container usage. Other type annotations, e.g.
# `Tuple[Container, Container]`, are also not allowed, as they do not
# work with `with_container_arithmetic`.
#
# This is not set in stone, but mostly driven by current usage!
origin = get_origin(field_type)
# NOTE: `UnionType` is returned when using `Type1 | Type2`
if origin in (Union, UnionType):
if all(is_array_type(arg) for arg in get_args(field_type)):
return True
else:
raise TypeError(
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")
# NOTE: this should never happen due to using `inspect.get_annotations`
assert not isinstance(field_type, str)
if __debug__:
if not f.init:
raise ValueError(
f"Field with 'init=False' not allowed: '{f.name}'")
# NOTE:
# * `GenericAlias` catches typed `list`, `tuple`, etc.
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
# * `_SpecialForm` catches `Any`, `Literal`, etc.
from typing import ( # type: ignore[attr-defined]
_BaseGenericAlias,
_SpecialForm,
)
if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"Typing annotation not supported on field '{f.name}': "
f"'{field_type!r}'")
if not isinstance(field_type, type):
raise TypeError(
f"Field '{f.name}' not an instance of 'type': "
f"'{field_type!r}'")
return is_array_type(field_type)
from pytools import partition
array_fields = _get_annotated_fields(cls)
array_fields, non_array_fields = partition(is_array_field, array_fields)
if not array_fields: if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type " raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator") "in order to use the 'dataclass_array_container' decorator")
return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
def _get_annotated_fields(cls: type) -> Sequence[_Field]:
"""Get a list of fields in the class *cls* with evaluated types.
If any of the fields in *cls* have type annotations that are strings, e.g.
from using ``from __future__ import annotations``, this function evaluates
them using :func:`inspect.get_annotations`. Note that this requires the class
to live in a module that is importable.
:return: a list of fields.
"""
from inspect import get_annotations
result = []
cls_ann: Mapping[str, type] | None = None
for field in fields(cls):
field_type_or_str = field.type
if isinstance(field_type_or_str, str):
if cls_ann is None:
cls_ann = get_annotations(cls, eval_str=True)
field_type = cls_ann[field.name]
else:
field_type = field_type_or_str
result.append(_Field(init=field.init, name=field.name, type=field_type))
return result
def _inject_dataclass_serialization(
cls: type,
array_fields: Sequence[_Field],
non_array_fields: Sequence[_Field]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
This function modifies *cls* in place, so the returned value is the same
object with additional functionality.
:arg array_fields: fields of the given dataclass *cls* which are considered
array containers and should be serialized.
:arg non_array_fields: remaining fields of the dataclass *cls* which are
copied over from the template array in deserialization.
"""
assert is_dataclass(cls)
serialize_expr = ", ".join( serialize_expr = ", ".join(
f"({f.name!r}, ary.{f.name})" for f in array_fields) f"({f.name!r}, ary.{f.name})" for f in array_fields)
template_kwargs = ", ".join( template_kwargs = ", ".join(
...@@ -106,7 +243,8 @@ def dataclass_array_container(cls: type) -> type: ...@@ -106,7 +243,8 @@ def dataclass_array_container(cls: type) -> type:
""") """)
exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": serialize_code} exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": serialize_code}
exec(compile(serialize_code, "<generated code>", "exec"), exec_dict) exec(compile(serialize_code, f"<container serialization for {cls.__name__}>",
"exec"), exec_dict)
return cls return cls
......
This diff is collapsed.
This diff is collapsed.
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
""" """
...@@ -23,21 +26,25 @@ THE SOFTWARE. ...@@ -23,21 +26,25 @@ THE SOFTWARE.
""" """
import operator
from abc import ABC, abstractmethod
from typing import Any
import numpy as np import numpy as np
from arraycontext.container import is_array_container, serialize_container
from arraycontext.container.traversal import ( from arraycontext.container import NotAnArrayContainerError, serialize_container
rec_map_array_container, multimapped_over_array_containers) from arraycontext.container.traversal import rec_map_array_container
# {{{ BaseFakeNumpyNamespace # {{{ BaseFakeNumpyNamespace
class BaseFakeNumpyNamespace: class BaseFakeNumpyNamespace(ABC):
def __init__(self, array_context): def __init__(self, array_context):
self._array_context = array_context self._array_context = array_context
self.linalg = self._get_fake_numpy_linalg_namespace() self.linalg = self._get_fake_numpy_linalg_namespace()
def _get_fake_numpy_linalg_namespace(self): def _get_fake_numpy_linalg_namespace(self):
return BaseFakeNumpyLinalgNamespace(self.array_context) return BaseFakeNumpyLinalgNamespace(self._array_context)
_numpy_math_functions = frozenset({ _numpy_math_functions = frozenset({
# https://numpy.org/doc/stable/reference/routines.math.html # https://numpy.org/doc/stable/reference/routines.math.html
...@@ -86,76 +93,20 @@ class BaseFakeNumpyNamespace: ...@@ -86,76 +93,20 @@ class BaseFakeNumpyNamespace:
# Miscellaneous # Miscellaneous
"convolve", "clip", "sqrt", "cbrt", "square", "absolute", "abs", "fabs", "convolve", "clip", "sqrt", "cbrt", "square", "absolute", "abs", "fabs",
"sign", "heaviside", "maximum", "fmax", "nan_to_num", "sign", "heaviside", "maximum", "fmax", "nan_to_num", "isnan", "minimum",
"fmin",
# FIXME: # FIXME:
# "interp", # "interp",
}) })
_numpy_to_c_arc_functions = { @abstractmethod
"arcsin": "asin", def zeros(self, shape, dtype):
"arccos": "acos", ...
"arctan": "atan",
"arctan2": "atan2",
"arcsinh": "asinh",
"arccosh": "acosh",
"arctanh": "atanh",
}
_c_to_numpy_arc_functions = {c_name: numpy_name
for numpy_name, c_name in _numpy_to_c_arc_functions.items()}
def __getattr__(self, name):
def loopy_implemented_elwise_func(*args):
actx = self._array_context
# FIXME: Maybe involve loopy type inference?
result = actx.empty(args[0].shape, args[0].dtype)
prg = actx._get_scalar_func_loopy_program(
c_name, nargs=len(args), naxes=len(args[0].shape))
actx.call_loopy(prg, out=result,
**{"inp%d" % i: arg for i, arg in enumerate(args)})
return result
if name in self._c_to_numpy_arc_functions:
from warnings import warn
warn(f"'{name}' in ArrayContext.np is deprecated. "
"Use '{c_to_numpy_arc_functions[name]}' as in numpy. "
"The old name will stop working in 2021.",
DeprecationWarning, stacklevel=3)
# normalize to C names anyway
c_name = self._numpy_to_c_arc_functions.get(name, name)
# limit which functions we try to hand off to loopy
if name in self._numpy_math_functions:
return multimapped_over_array_containers(loopy_implemented_elwise_func)
else:
raise AttributeError(name)
def _new_like(self, ary, alloc_like):
from numbers import Number
if isinstance(ary, np.ndarray) and ary.dtype.char == "O":
# NOTE: we don't want to match numpy semantics on object arrays,
# e.g. `np.zeros_like(x)` returns `array([0, 0, ...], dtype=object)`
# FIXME: what about object arrays nested in an ArrayContainer?
raise NotImplementedError("operation not implemented for object arrays")
elif is_array_container(ary):
return rec_map_array_container(alloc_like, ary)
elif isinstance(ary, Number):
# NOTE: `np.zeros_like(x)` returns `array(x, shape=())`, which
# is best implemented by concrete array contexts, if at all
raise NotImplementedError("operation not implemented for scalars")
else:
return alloc_like(ary)
def empty_like(self, ary):
return self._new_like(ary, self._array_context.empty_like)
@abstractmethod
def zeros_like(self, ary): def zeros_like(self, ary):
return self._new_like(ary, self._array_context.zeros_like) ...
def conjugate(self, x): def conjugate(self, x):
# NOTE: conjugate distributes over object arrays, but it looks for a # NOTE: conjugate distributes over object arrays, but it looks for a
...@@ -165,41 +116,170 @@ class BaseFakeNumpyNamespace: ...@@ -165,41 +116,170 @@ class BaseFakeNumpyNamespace:
conj = conjugate conj = conjugate
# {{{ linspace
# based on
# https://github.com/numpy/numpy/blob/v1.25.0/numpy/core/function_base.py#L24-L182
def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis=0):
num = operator.index(num)
if num < 0:
raise ValueError(f"Number of samples, {num}, must be non-negative.")
div = (num - 1) if endpoint else num
# Convert float/complex array scalars to float, gh-3504
# and make sure one can use variables that have an __array_interface__,
# gh-6634
if isinstance(start, self._array_context.array_types):
raise NotImplementedError("start as an actx array")
if isinstance(stop, self._array_context.array_types):
raise NotImplementedError("stop as an actx array")
start = np.array(start) * 1.0
stop = np.array(stop) * 1.0
dt = np.result_type(start, stop, float(num))
if dtype is None:
dtype = dt
integer_dtype = False
else:
integer_dtype = np.issubdtype(dtype, np.integer)
delta = stop - start
y = self.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim)
if div > 0:
step = delta / div
# any_step_zero = _nx.asanyarray(step == 0).any()
any_step_zero = self._array_context.to_numpy(step == 0).any()
if any_step_zero:
delta_actx = self._array_context.from_numpy(delta)
# Special handling for denormal numbers, gh-5437
y = y / div
y = y * delta_actx
else:
step_actx = self._array_context.from_numpy(step)
y = y * step_actx
else:
delta_actx = self._array_context.from_numpy(delta)
# sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0)
# have an undefined step
step = np.nan
# Multiply with delta to allow possible override of output class.
y = y * delta_actx
y += start
# FIXME reenable, without in-place ops
# if endpoint and num > 1:
# y[-1, ...] = stop
if axis != 0:
# y = _nx.moveaxis(y, 0, axis)
raise NotImplementedError("axis != 0")
if integer_dtype:
y = self.floor(y) # pylint: disable=no-member
# FIXME: Use astype
# https://github.com/inducer/pytato/issues/456
if retstep:
return y, step
# return y.astype(dtype), step
else:
return y
# return y.astype(dtype)
# }}}
def arange(self, *args: Any, **kwargs: Any):
raise NotImplementedError
# }}} # }}}
# {{{ BaseFakeNumpyLinalgNamespace # {{{ BaseFakeNumpyLinalgNamespace
def _reduce_norm(actx, arys, ord):
from functools import reduce
from numbers import Number
if ord is None:
ord = 2
# NOTE: these are ordered by an expected usage frequency
if ord == 2:
return actx.np.sqrt(sum(subary*subary for subary in arys))
elif ord == np.inf:
return reduce(actx.np.maximum, arys)
elif ord == -np.inf:
return reduce(actx.np.minimum, arys)
elif isinstance(ord, Number) and ord > 0:
return sum(subary**ord for subary in arys)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
class BaseFakeNumpyLinalgNamespace: class BaseFakeNumpyLinalgNamespace:
def __init__(self, array_context): def __init__(self, array_context):
self._array_context = array_context self._array_context = array_context
def norm(self, ary, ord=None): def norm(self, ary, ord=None):
from numbers import Number if np.isscalar(ary):
if isinstance(ary, Number):
return abs(ary) return abs(ary)
if is_array_container(ary): actx = self._array_context
import numpy.linalg as la
return la.norm( try:
[self.norm(subary, ord=ord) from meshmode.dof_array import DOFArray, flat_norm
for _, subary in serialize_container(ary)], except ImportError:
ord=ord) pass
else:
if isinstance(ary, DOFArray):
from warnings import warn
warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
"(DOFArrays use 2D arrays internally, and "
"actx.np.linalg.norm should compute matrix norms of those.) "
"This will stop working in 2022. "
"Use meshmode.dof_array.flat_norm instead.",
DeprecationWarning, stacklevel=2)
return flat_norm(ary, ord=ord)
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError:
pass
else:
return _reduce_norm(actx, [
self.norm(subary, ord=ord) for _, subary in iterable
], ord=ord)
if ord is None:
return self.norm(actx.np.ravel(ary, order="A"), 2)
if len(ary.shape) != 1: if len(ary.shape) != 1:
raise NotImplementedError("only vector norms are implemented") raise NotImplementedError("only vector norms are implemented")
if ary.size == 0: if ary.size == 0:
return 0 return ary.dtype.type(0)
from numbers import Number
if ord == 2:
return actx.np.sqrt(actx.np.sum(abs(ary)**2))
if ord == np.inf: if ord == np.inf:
return self._array_context.np.max(abs(ary)) return actx.np.max(abs(ary))
elif ord == -np.inf:
return actx.np.min(abs(ary))
elif isinstance(ord, Number) and ord > 0: elif isinstance(ord, Number) and ord > 0:
return self._array_context.np.sum(abs(ary)**ord)**(1/ord) return actx.np.sum(abs(ary)**ord)**(1/ord)
else: else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}") raise NotImplementedError(f"unsupported value of 'ord': {ord}")
# }}} # }}}
......
from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
""" """
...@@ -21,12 +24,3 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, ...@@ -21,12 +24,3 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
def _is_meshmode_dofarray(x):
try:
from meshmode.dof_array import DOFArray
except ImportError:
return False
else:
return isinstance(x, DOFArray)
"""
.. currentmodule:: arraycontext
.. autoclass:: EagerJAXArrayContext
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from collections.abc import Callable
import numpy as np
from pytools.tag import ToTagSetConvertible
from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
class EagerJAXArrayContext(ArrayContext):
"""
A :class:`ArrayContext` that uses
:class:`jax.Array` instances for its base array
class and performs all array operations eagerly. See
:class:`~arraycontext.PytatoJAXArrayContext` for a lazier version.
.. note::
JAX stores a global configuration state in :data:`jax.config`. Callers
are expected to maintain those. Most important for scientific computing
workloads being ``jax_enable_x64``.
"""
def __init__(self) -> None:
super().__init__()
import jax.numpy as jnp
self.array_types = (jnp.ndarray, )
def _get_fake_numpy_namespace(self):
from .fake_numpy import EagerJAXFakeNumpyNamespace
return EagerJAXFakeNumpyNamespace(self)
def _rec_map_container(
self, func: Callable[[Array], Array], array: ArrayOrContainer,
allowed_types: tuple[type, ...] | None = None, *,
default_scalar: ScalarLike | None = None,
strict: bool = False) -> ArrayOrContainer:
if allowed_types is None:
allowed_types = self.array_types
def _wrapper(ary):
if isinstance(ary, allowed_types):
return func(ary)
elif np.isscalar(ary):
if default_scalar is None:
return ary
else:
return np.array(ary).dtype.type(default_scalar)
else:
raise TypeError(
f"{type(self).__name__}.{func.__name__[1:]} invoked with "
f"an unsupported array type: got '{type(ary).__name__}', "
f"but expected one of {allowed_types}")
return rec_map_array_container(_wrapper, array)
# {{{ ArrayContext interface
def from_numpy(self, array):
def _from_numpy(ary):
import jax
return jax.device_put(ary)
return with_array_context(
self._rec_map_container(_from_numpy, array, allowed_types=(np.ndarray,)),
actx=self)
def to_numpy(self, array):
def _to_numpy(ary):
import jax
return jax.device_get(ary)
return with_array_context(
self._rec_map_container(_to_numpy, array),
actx=None)
def freeze(self, array):
def _freeze(ary):
return ary.block_until_ready()
return with_array_context(self._rec_map_container(_freeze, array), actx=None)
def thaw(self, array):
return with_array_context(array, actx=self)
def tag(self, tags: ToTagSetConvertible, array):
# Sorry, not capable.
return array
def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
# TODO: See `jax.experimental.maps.xmap`, probably that should be useful?
return array
def call_loopy(self, t_unit, **kwargs):
raise NotImplementedError(
"Calling loopy on JAX arrays is not supported. Maybe rewrite"
" the loopy kernel as numpy-flavored array operations using"
" ArrayContext.np.")
def einsum(self, spec, *args, arg_names=None, tagged=()):
import jax.numpy as jnp
if arg_names is not None:
from warnings import warn
warn("'arg_names' don't bear any significance in "
f"{type(self).__name__}.", stacklevel=2)
return jnp.einsum(spec, *args)
def clone(self):
return type(self)()
# }}}
# {{{ properties
@property
def permits_inplace_modification(self):
return False
@property
def supports_nonscalar_broadcasting(self):
return True
@property
def permits_advanced_indexing(self):
return True
# }}}
# vim: foldmethod=marker
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial, reduce
import numpy as np
import jax.numpy as jnp
from arraycontext.container import (
NotAnArrayContainerError,
serialize_container,
)
from arraycontext.container.traversal import (
rec_map_array_container,
rec_map_reduce_array_container,
rec_multimap_array_container,
)
from arraycontext.context import Array, ArrayOrContainer
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace
class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
# Everything is implemented in the base class for now.
pass
class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
"""
A :mod:`numpy` mimic for :class:`~arraycontext.EagerJAXArrayContext`.
"""
def _get_fake_numpy_linalg_namespace(self):
return EagerJAXFakeNumpyLinalgNamespace(self._array_context)
def __getattr__(self, name):
return partial(rec_multimap_array_container, getattr(jnp, name))
# NOTE: the order of these follows the order in numpy docs
# NOTE: when adding a function here, also add it to `array_context.rst` docs!
# {{{ array creation routines
def zeros(self, shape, dtype):
return jnp.zeros(shape=shape, dtype=dtype)
def empty_like(self, ary):
from warnings import warn
warn(f"{type(self._array_context).__name__}.np.empty_like is "
"deprecated and will stop working in 2023. Prefer actx.np.zeros_like "
"instead.",
DeprecationWarning, stacklevel=2)
def _empty_like(array):
return self._array_context.empty(array.shape, array.dtype)
return self._array_context._rec_map_container(_empty_like, ary)
def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.np.zeros(array.shape, array.dtype)
return self._array_context._rec_map_container(
_zeros_like, ary, default_scalar=0)
def ones_like(self, ary):
return self.full_like(ary, 1)
def full_like(self, ary, fill_value):
def _full_like(subary):
return jnp.full_like(subary, fill_value)
return self._array_context._rec_map_container(
_full_like, ary, default_scalar=fill_value)
# }}}
# {{{ array manipulation routies
def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: jnp.reshape(ary, newshape, order=order),
a)
def ravel(self, a, order="C"):
"""
.. warning::
Since :func:`jax.numpy.reshape` does not support orders `A`` and
``K``, in such cases we fallback to using ``order = C``.
"""
if order in "AK":
from warnings import warn
warn(f"ravel with order='{order}' not supported by JAX,"
" using order=C.", stacklevel=1)
order = "C"
return rec_map_array_container(
lambda subary: jnp.ravel(subary, order=order), a)
def transpose(self, a, axes=None):
return rec_multimap_array_container(jnp.transpose, a, axes)
def broadcast_to(self, array, shape):
return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
def concatenate(self, arrays, axis=0):
return rec_multimap_array_container(jnp.concatenate, arrays, axis)
def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: jnp.stack(arrays=args, axis=axis),
*arrays)
# }}}
# {{{ linear algebra
def vdot(self, x, y, dtype=None):
from arraycontext import rec_multimap_reduce_array_container
def _rec_vdot(ary1, ary2):
common_dtype = np.result_type(ary1, ary2)
if dtype not in (None, common_dtype):
raise NotImplementedError(
f"{type(self).__name__} cannot take dtype in vdot.")
return jnp.vdot(ary1, ary2)
return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
# }}}
# {{{ logic functions
def all(self, a):
return rec_map_reduce_array_container(
partial(reduce, jnp.logical_and), jnp.all, a)
def any(self, a):
return rec_map_reduce_array_container(
partial(reduce, jnp.logical_or), jnp.any, a)
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
actx = self._array_context
# NOTE: not all backends support `bool` properly, so use `int8` instead
true_ary = actx.from_numpy(np.int8(True))
false_ary = actx.from_numpy(np.int8(False))
def rec_equal(x, y):
if type(x) is not type(y):
return false_ary
try:
serialized_x = serialize_container(x)
serialized_y = serialize_container(y)
except NotAnArrayContainerError:
if x.shape != y.shape:
return false_ary
else:
return jnp.all(jnp.equal(x, y))
else:
if len(serialized_x) != len(serialized_y):
return false_ary
return reduce(
jnp.logical_and,
[(true_ary if kx_i == ky_i else false_ary)
and rec_equal(x_i, y_i)
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y, strict=True)],
true_ary)
return rec_equal(a, b)
# }}}
# {{{ mathematical functions
def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(
sum,
partial(jnp.sum, axis=axis, dtype=dtype),
a)
def amin(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
min = amin
def amax(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
max = amax
# }}}
# {{{ sorting, searching and counting
def where(self, criterion, then, else_):
return rec_multimap_array_container(jnp.where, criterion, then, else_)
# }}}
"""
.. currentmodule:: arraycontext
A :mod:`numpy`-based array context.
.. autoclass:: NumpyArrayContext
"""
from __future__ import annotations
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from typing import Any, overload
import numpy as np
import loopy as lp
from pytools.tag import ToTagSetConvertible
from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ContainerOrScalarT,
NumpyOrContainerOrScalar,
UntransformedCodeWarning,
)
class NumpyNonObjectArrayMetaclass(type):
def __instancecheck__(cls, instance: Any) -> bool:
return isinstance(instance, np.ndarray) and instance.dtype != object
class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass):
pass
class NumpyArrayContext(ArrayContext):
"""
A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays.
.. automethod:: __init__
"""
_loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]
def __init__(self) -> None:
super().__init__()
self._loopy_transform_cache = {}
array_types = (NumpyNonObjectArray,)
def _get_fake_numpy_namespace(self):
from .fake_numpy import NumpyFakeNumpyNamespace
return NumpyFakeNumpyNamespace(self)
# {{{ ArrayContext interface
def clone(self):
return type(self)()
@overload
def from_numpy(self, array: np.ndarray) -> Array:
...
@overload
def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...
def from_numpy(self,
array: NumpyOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
return array
@overload
def to_numpy(self, array: Array) -> np.ndarray:
...
@overload
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...
def to_numpy(self,
array: ArrayOrContainerOrScalar
) -> NumpyOrContainerOrScalar:
return array
def call_loopy(
self,
t_unit: lp.TranslationUnit, **kwargs: Any
) -> dict[str, Array]:
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
try:
executor = self._loopy_transform_cache[t_unit]
except KeyError:
executor = self.transform_loopy_program(t_unit).executor()
self._loopy_transform_cache[t_unit] = executor
_, result = executor(**kwargs)
return result
def freeze(self, array):
def _freeze(ary):
return ary
return with_array_context(rec_map_array_container(_freeze, array), actx=None)
def thaw(self, array):
def _thaw(ary):
return ary
return with_array_context(rec_map_array_container(_thaw, array), actx=self)
# }}}
def transform_loopy_program(self, t_unit):
from warnings import warn
warn("Using the base "
f"{type(self).__name__}.transform_loopy_program "
"to transform a translation unit. "
"This is a no-op and will result in unoptimized C code for"
"the requested optimization, all in a single statement."
"This will work, but is unlikely to be performant."
f"Instead, subclass {type(self).__name__} and implement "
"the specific transform logic required to transform the program "
"for your package or application. Check higher-level packages "
"(e.g. meshmode), which may already have subclasses you may want "
"to build on.",
UntransformedCodeWarning, stacklevel=2)
return t_unit
def tag(self,
tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
# Numpy doesn't support tagging
return array
def tag_axis(self,
iaxis: int, tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
# Numpy doesn't support tagging
return array
def einsum(self, spec, *args, arg_names=None, tagged=()):
return np.einsum(spec, *args)
@property
def permits_inplace_modification(self):
return True
@property
def supports_nonscalar_broadcasting(self):
return True
@property
def permits_advanced_indexing(self):
return True