From 600c659d05ccc4d65ac8887858ba80a6d5305c44 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Wed, 11 Jan 2023 17:23:00 +0200 Subject: [PATCH] enable and apply isort --- arraycontext/__init__.py | 73 +++++++------------ arraycontext/container/__init__.py | 8 +- arraycontext/container/arithmetic.py | 3 +- arraycontext/container/dataclass.py | 4 +- arraycontext/container/traversal.py | 19 ++--- arraycontext/context.py | 10 ++- arraycontext/fake_numpy.py | 3 +- arraycontext/impl/jax/__init__.py | 7 +- arraycontext/impl/jax/fake_numpy.py | 14 ++-- arraycontext/impl/pyopencl/__init__.py | 11 +-- arraycontext/impl/pyopencl/fake_numpy.py | 18 ++--- .../impl/pyopencl/taggable_cl_array.py | 4 +- arraycontext/impl/pytato/__init__.py | 55 ++++++++------ arraycontext/impl/pytato/compile.py | 38 +++++----- arraycontext/impl/pytato/fake_numpy.py | 17 ++--- arraycontext/impl/pytato/utils.py | 14 ++-- arraycontext/loopy.py | 8 +- arraycontext/metadata.py | 7 +- arraycontext/pytest.py | 20 +++-- setup.cfg | 8 ++ test/test_arraycontext.py | 43 +++++------ test/test_pytato_arraycontext.py | 14 ++-- test/test_utils.py | 13 ++-- 23 files changed, 205 insertions(+), 206 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 06e0b96..00e542d 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -29,61 +29,38 @@ THE SOFTWARE. """ import sys -from .context import ( - ArrayContext, - - Scalar, ScalarLike, - Array, ArrayT, - ArrayOrContainer, ArrayOrContainerT, - ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, - - tag_axes) - -from .transform_metadata import (CommonSubexpressionTag, - ElementwiseMapKernelTag) - -# deprecated, remove in 2022. -from .metadata import _FirstAxisIsElementsTag from .container import ( - ArrayContainer, ArrayContainerT, - NotAnArrayContainerError, - is_array_container, is_array_container_type, - get_container_context_opt, - get_container_context_recursively, get_container_context_recursively_opt, - serialize_container, deserialize_container, - register_multivector_as_array_container) + ArrayContainer, ArrayContainerT, NotAnArrayContainerError, deserialize_container, + get_container_context_opt, get_container_context_recursively, + get_container_context_recursively_opt, is_array_container, + is_array_container_type, register_multivector_as_array_container, + serialize_container) from .container.arithmetic import with_container_arithmetic from .container.dataclass import dataclass_array_container - from .container.traversal import ( - map_array_container, - multimap_array_container, - rec_map_array_container, - rec_multimap_array_container, - mapped_over_array_containers, - multimapped_over_array_containers, - map_reduce_array_container, - multimap_reduce_array_container, - rec_map_reduce_array_container, - rec_multimap_reduce_array_container, - thaw, freeze, - flatten, unflatten, flat_size_and_dtype, - from_numpy, to_numpy, - outer, with_array_context) - -from .impl.pyopencl import PyOpenCLArrayContext -from .impl.pytato import (PytatoPyOpenCLArrayContext, - PytatoJAXArrayContext) + flat_size_and_dtype, flatten, freeze, from_numpy, map_array_container, + map_reduce_array_container, mapped_over_array_containers, + multimap_array_container, multimap_reduce_array_container, + multimapped_over_array_containers, outer, rec_map_array_container, + rec_map_reduce_array_container, rec_multimap_array_container, + rec_multimap_reduce_array_container, thaw, to_numpy, unflatten, + with_array_context) +from .context import ( + Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar, + ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike, + tag_axes) from .impl.jax import EagerJAXArrayContext - -from .pytest import ( - PytestArrayContextFactory, - PytestPyOpenCLArrayContextFactory, - pytest_generate_tests_for_array_contexts, - pytest_generate_tests_for_pyopencl_array_context) - +from .impl.pyopencl import PyOpenCLArrayContext +from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext from .loopy import make_loopy_program +# deprecated, remove in 2022. +from .metadata import _FirstAxisIsElementsTag +from .pytest import ( + PytestArrayContextFactory, PytestPyOpenCLArrayContextFactory, + pytest_generate_tests_for_array_contexts, + pytest_generate_tests_for_pyopencl_array_context) +from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag __all__ = ( diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 71bccee..fcb130f 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -69,17 +69,19 @@ THE SOFTWARE. """ from functools import singledispatch -from arraycontext.context import ArrayContext -from typing import Any, Iterable, Tuple, Optional, TypeVar, Protocol, TYPE_CHECKING -import numpy as np +from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple, TypeVar # For use in singledispatch type annotations, because sphinx can't figure out # what 'np' is. import numpy +import numpy as np + +from arraycontext.context import ArrayContext if TYPE_CHECKING: from pymbolic.geometric_algebra import MultiVector + from arraycontext import ArrayOrContainer diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 5e2ade2..148d34b 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -7,6 +7,7 @@ import enum + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ @@ -31,8 +32,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union from warnings import warn -from typing import Any, Callable, Optional, Tuple, TypeVar, Union, Type import numpy as np diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index ee5c2f0..e9ab38d 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -30,9 +30,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from dataclasses import Field, fields, is_dataclass from typing import Tuple, Union, get_args, get_origin -from dataclasses import Field, is_dataclass, fields from arraycontext.container import is_array_container_type @@ -95,7 +95,7 @@ def dataclass_array_container(cls: type) -> type: # NOTE: # * `_BaseGenericAlias` catches `List`, `Tuple`, etc. # * `_SpecialForm` catches `Any`, `Literal`, etc. - from typing import ( # type: ignore[attr-defined] + from typing import ( # type: ignore[attr-defined] _BaseGenericAlias, _SpecialForm) if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)): # NOTE: anything except a Union is not allowed diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index c41b464..048f6a1 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -41,6 +41,7 @@ Algebraic operations from __future__ import annotations + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ @@ -65,22 +66,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable, Iterable, List, Optional, Union, Tuple, cast -from functools import update_wrapper, partial, singledispatch +from functools import partial, singledispatch, update_wrapper +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast from warnings import warn import numpy as np -from arraycontext.context import ( - ArrayT, ArrayOrContainer, ArrayOrContainerT, - ArrayOrContainerOrScalar, ScalarLike, - ArrayContext, Array -) from arraycontext.container import ( - NotAnArrayContainerError, - ArrayContainer, - serialize_container, deserialize_container, - get_container_context_recursively_opt) + ArrayContainer, NotAnArrayContainerError, deserialize_container, + get_container_context_recursively_opt, serialize_container) +from arraycontext.context import ( + Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar, + ArrayOrContainerT, ArrayT, ScalarLike) # {{{ array container traversal helpers diff --git a/arraycontext/context.py b/arraycontext/context.py index 2378550..f844106 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -160,15 +160,18 @@ THE SOFTWARE. from abc import ABC, abstractmethod from typing import ( - Any, Callable, Dict, Optional, Tuple, Union, Mapping, Protocol, TypeVar, - TYPE_CHECKING) + TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, Protocol, Tuple, TypeVar, + Union) import numpy as np + from pytools import memoize_method from pytools.tag import ToTagSetConvertible + if TYPE_CHECKING: import loopy + from arraycontext.container import ArrayContainer @@ -426,8 +429,9 @@ class ArrayContext(ABC): spec: str, arg_names: Tuple[str, ...], tagged: ToTagSetConvertible) -> "loopy.TranslationUnit": import loopy as lp - from .loopy import _DEFAULT_LOOPY_OPTIONS from loopy.version import MOST_RECENT_LANGUAGE_VERSION + + from .loopy import _DEFAULT_LOOPY_OPTIONS return lp.make_einsum( spec, arg_names, diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index c3e37f8..cf416fa 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -24,6 +24,7 @@ THE SOFTWARE. import numpy as np + from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import rec_map_array_container @@ -105,8 +106,8 @@ class BaseFakeNumpyNamespace: # {{{ BaseFakeNumpyLinalgNamespace def _reduce_norm(actx, arys, ord): - from numbers import Number from functools import reduce + from numbers import Number if ord is None: ord = 2 diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 4aa30c2..e680f7e 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -32,9 +32,10 @@ from typing import Callable, Optional, Tuple import numpy as np from pytools.tag import ToTagSetConvertible -from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike -from arraycontext.container.traversal import (with_array_context, - rec_map_array_container) + +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) +from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike class EagerJAXArrayContext(ArrayContext): diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index daaf880..0955820 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -23,17 +23,15 @@ THE SOFTWARE. """ from functools import partial, reduce -import numpy as np import jax.numpy as jnp +import numpy as np -from arraycontext.fake_numpy import ( - BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, - ) -from arraycontext.container.traversal import ( - rec_multimap_array_container, rec_map_array_container, - rec_map_reduce_array_container, - ) from arraycontext.container import NotAnArrayContainerError, serialize_container +from arraycontext.container.traversal import ( + rec_map_array_container, rec_map_reduce_array_container, + rec_multimap_array_container) +from arraycontext.fake_numpy import ( + BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace) class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index aced309..4064689 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -28,21 +28,21 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple from warnings import warn -from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING import numpy as np from pytools.tag import ToTagSetConvertible -from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike -from arraycontext.container.traversal import (rec_map_array_container, - with_array_context) +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) +from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike if TYPE_CHECKING: - import pyopencl import loopy as lp + import pyopencl # {{{ PyOpenCLArrayContext @@ -287,6 +287,7 @@ class PyOpenCLArrayContext(ArrayContext): wait_event_queue.pop(0).wait() import arraycontext.impl.pyopencl.taggable_cl_array as tga + # FIXME: Inherit loopy tags for these arrays return {name: tga.to_tagged_cl_array(ary) for name, ary in result.items()} diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index a0180e7..71af653 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -26,24 +26,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from functools import partial, reduce import operator +from functools import partial, reduce import numpy as np -from arraycontext.fake_numpy import ( - BaseFakeNumpyLinalgNamespace - ) -from arraycontext.loopy import ( - LoopyBasedFakeNumpyNamespace - ) from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( - rec_map_array_container, - rec_multimap_array_container, - rec_map_reduce_array_container, - rec_multimap_reduce_array_container, - ) + rec_map_array_container, rec_map_reduce_array_container, + rec_multimap_array_container, rec_multimap_reduce_array_container) +from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace +from arraycontext.loopy import LoopyBasedFakeNumpyNamespace + try: import pyopencl as cl # noqa: F401 diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index 49ae08b..0361200 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -9,10 +9,10 @@ from dataclasses import dataclass from typing import Any, Dict, FrozenSet, Optional, Tuple import numpy as np -import pyopencl.array as cla +import pyopencl.array as cla from pytools import memoize -from pytools.tag import Taggable, Tag, ToTagSetConvertible +from pytools.tag import Tag, Taggable, ToTagSetConvertible # {{{ utils diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index c3e4462..8fe3055 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -44,27 +44,30 @@ THE SOFTWARE. import abc import sys -from typing import (Any, Callable, Union, Tuple, Type, FrozenSet, Dict, Optional, - TYPE_CHECKING) +from typing import ( + TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Optional, Tuple, Type, Union) import numpy as np -from pytools.tag import ToTagSetConvertible, normalize_tags, Tag -from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike -from arraycontext.container.traversal import (rec_map_array_container, - with_array_context) -from arraycontext.metadata import NameHint from pytools import memoize_method +from pytools.tag import Tag, ToTagSetConvertible, normalize_tags + +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) +from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike +from arraycontext.metadata import NameHint + if TYPE_CHECKING: - import pytato import pyopencl as cl + import pytato if getattr(sys, "_BUILDING_SPHINX_DOCS", False): import pyopencl as cl # noqa: F811 - import logging + + logger = logging.getLogger(__name__) @@ -126,8 +129,8 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC): """ super().__init__() - import pytato as pt import loopy as lp + import pytato as pt self._freeze_prg_cache: Dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {} self._dag_transform_cache: Dict[ pt.DictOfNamedArrays, @@ -292,8 +295,8 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): except ImportError: self.using_svm = False - import pytato as pt import pyopencl.array as cla + import pytato as pt super().__init__(compile_trace_callback=compile_trace_callback) self.queue = queue @@ -316,6 +319,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): default_scalar: Optional[ScalarLike] = None, strict: bool = False) -> ArrayOrContainer: import pytato as pt + import arraycontext.impl.pyopencl.taggable_cl_array as tga if allowed_types is None: @@ -357,6 +361,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): def from_numpy(self, array): import pytato as pt + import arraycontext.impl.pyopencl.taggable_cl_array as tga def _from_numpy(ary): @@ -415,8 +420,8 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): logger.info(f"limiting argument buffer size for {dev} to {limit} bytes") - from arraycontext.impl.pytato.utils import \ - ArgSizeLimitingPytatoLoopyPyOpenCLTarget + from arraycontext.impl.pytato.utils import ( + ArgSizeLimitingPytatoLoopyPyOpenCLTarget) return ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit) else: return super().get_target() @@ -425,15 +430,15 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): if np.isscalar(array): return array - import pytato as pt import pyopencl.array as cla + import pytato as pt from arraycontext.container.traversal import rec_keyed_map_array_container - from arraycontext.impl.pytato.utils import (_normalize_pt_expr, - get_cl_axes_from_pt_axes) - from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array, - TaggableCLArray) + from arraycontext.impl.pyopencl.taggable_cl_array import ( + TaggableCLArray, to_tagged_cl_array) from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + from arraycontext.impl.pytato.utils import ( + _normalize_pt_expr, get_cl_axes_from_pt_axes) array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray, pt.Array]] = {} key_to_frozen_subary: Dict[str, TaggableCLArray] = {} @@ -549,8 +554,9 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): def thaw(self, array): import pytato as pt - from .utils import get_pt_axes_from_cl_axes + import arraycontext.impl.pyopencl.taggable_cl_array as tga + from .utils import get_pt_axes_from_cl_axes def _thaw(ary): return pt.make_data_wrapper(ary.with_queue(self.queue), @@ -579,8 +585,9 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): def call_loopy(self, program, **kwargs): import pytato as pt - from pytato.scalar_expr import SCALAR_CLASSES from pytato.loopy import call_loopy + from pytato.scalar_expr import SCALAR_CLASSES + from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray entrypoint = program.default_entrypoint.name @@ -617,6 +624,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): def einsum(self, spec, *args, arg_names=None, tagged=()): import pytato as pt + import arraycontext.impl.pyopencl.taggable_cl_array as tga if arg_names is None: @@ -685,8 +693,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): representation. This interface should be considered unstable. """ - import pytato as pt import jax.numpy as jnp + + import pytato as pt super().__init__(compile_trace_callback=compile_trace_callback) self.array_types = (pt.Array, jnp.ndarray) @@ -731,6 +740,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): def from_numpy(self, array): import jax + import pytato as pt def _from_numpy(ary): @@ -754,9 +764,10 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): if np.isscalar(array): return array + import jax.numpy as jnp + import pytato as pt - import jax.numpy as jnp from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.compile import _ary_container_key_stringifier diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 6fb56cc..d7e0416 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -29,26 +29,26 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from arraycontext.context import ArrayT -from arraycontext.container import ArrayContainer, is_array_container_type -from arraycontext.impl.pytato import (_BasePytatoArrayContext, - PytatoJAXArrayContext, - PytatoPyOpenCLArrayContext) -from arraycontext.container.traversal import rec_keyed_map_array_container - import abc -import numpy as np -from typing import Any, Callable, Tuple, Dict, Mapping, FrozenSet, Type +import itertools +import logging from dataclasses import dataclass, field -from pyrsistent import pmap, PMap +from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Type + +import numpy as np +from pyrsistent import PMap, pmap import pytato as pt -import itertools +from pytools import ProcessLogger from pytools.tag import Tag -from pytools import ProcessLogger +from arraycontext.container import ArrayContainer, is_array_container_type +from arraycontext.container.traversal import rec_keyed_map_array_container +from arraycontext.context import ArrayT +from arraycontext.impl.pytato import ( + PytatoJAXArrayContext, PytatoPyOpenCLArrayContext, _BasePytatoArrayContext) + -import logging logger = logging.getLogger(__name__) @@ -185,8 +185,9 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): :meth:`PytatoPyOpenCLArrayContext.transform_dag`. """ import pyopencl.array as cla - from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array, - TaggableCLArray) + + from arraycontext.impl.pyopencl.taggable_cl_array import ( + TaggableCLArray, to_tagged_cl_array) if isinstance(ary, pt.Array): dag = pt.make_dict_of_named_arrays({"_actx_out": ary}) # Transform the DAG to give metadata inference a chance to do its job @@ -390,9 +391,8 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): if prg_id is None: prg_id = self.f - from pytato.target.loopy import BoundPyOpenCLProgram - import loopy as lp + from pytato.target.loopy import BoundPyOpenCLProgram self.actx._compile_trace_callback( prg_id, "pre_transform_dag", dict_of_named_arrays) @@ -632,8 +632,8 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: - from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array from .utils import get_cl_axes_from_pt_axes + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array input_kwargs_for_loopy = _args_to_device_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) @@ -674,8 +674,8 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: - from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array from .utils import get_cl_axes_from_pt_axes + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array input_kwargs_for_loopy = _args_to_device_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index e17f8ee..4dad159 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -25,19 +25,14 @@ from functools import partial, reduce import numpy as np -from arraycontext.fake_numpy import ( - BaseFakeNumpyLinalgNamespace - ) -from arraycontext.loopy import ( - LoopyBasedFakeNumpyNamespace - ) +import pytato as pt + from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( - rec_map_array_container, - rec_multimap_array_container, - rec_map_reduce_array_container, - ) -import pytato as pt + rec_map_array_container, rec_map_reduce_array_container, + rec_multimap_array_container) +from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace +from arraycontext.loopy import LoopyBasedFakeNumpyNamespace class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 9d77202..e0af81c 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -23,15 +23,17 @@ THE SOFTWARE. """ -from typing import Any, Dict, Set, Tuple, Mapping, Optional, TYPE_CHECKING -from pytools import memoize_method +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Set, Tuple -from pytato.array import SizeParam, Placeholder, make_placeholder, Axis as PtAxis -from pytato.array import Array, DataWrapper, DictOfNamedArrays +from pytato.array import ( + Array, Axis as PtAxis, DataWrapper, DictOfNamedArrays, Placeholder, SizeParam, + make_placeholder) +from pytato.target.loopy import LoopyPyOpenCLTarget from pytato.transform import CopyMapper -from pytools import UniqueNameGenerator +from pytools import UniqueNameGenerator, memoize_method + from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis -from pytato.target.loopy import LoopyPyOpenCLTarget + if TYPE_CHECKING: import loopy as lp diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index 1f90318..f8e54b5 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -28,12 +28,15 @@ THE SOFTWARE. """ import numpy as np + import loopy as lp from loopy.version import MOST_RECENT_LANGUAGE_VERSION -from arraycontext.fake_numpy import BaseFakeNumpyNamespace -from arraycontext.container.traversal import multimapped_over_array_containers from pytools import memoize_in +from arraycontext.container.traversal import multimapped_over_array_containers +from arraycontext.fake_numpy import BaseFakeNumpyNamespace + + # {{{ loopy _DEFAULT_LOOPY_OPTIONS = lp.Options( @@ -89,6 +92,7 @@ def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes): domain_bset, = domain.get_basic_sets() import loopy as lp + from .loopy import make_loopy_program from arraycontext.transform_metadata import ElementwiseMapKernelTag return make_loopy_program( diff --git a/arraycontext/metadata.py b/arraycontext/metadata.py index 39934d6..95fc639 100644 --- a/arraycontext/metadata.py +++ b/arraycontext/metadata.py @@ -29,9 +29,10 @@ THE SOFTWARE. import sys from dataclasses import dataclass -from pytools.tag import Tag, UniqueTag from warnings import warn +from pytools.tag import Tag, UniqueTag + @dataclass(frozen=True) class NameHint(UniqueTag): @@ -52,8 +53,8 @@ class NameHint(UniqueTag): # {{{ deprecation handling try: - from meshmode.transform_metadata import FirstAxisIsElementsTag \ - as _FirstAxisIsElementsTag + from meshmode.transform_metadata import ( + FirstAxisIsElementsTag as _FirstAxisIsElementsTag) except ImportError: # placeholder in case meshmode is too old to have it. class _FirstAxisIsElementsTag(Tag): # type: ignore[no-redef] diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index f84915b..4fce588 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -63,7 +63,7 @@ class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory): @classmethod def is_available(cls) -> bool: try: - import pyopencl # noqa: F401 + import pyopencl # noqa: F401 return True except ImportError: return False @@ -79,6 +79,7 @@ class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory): collect() import pyopencl as cl + # On Intel CPU CL, existence of a command queue does not ensure that # the context survives. ctx = cl.Context([self.device]) @@ -134,8 +135,8 @@ class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory @classmethod def is_available(cls) -> bool: try: - import pyopencl # noqa: F401 - import pytato # noqa: F401 + import pyopencl # noqa: F401 + import pytato # noqa: F401 return True except ImportError: return False @@ -182,14 +183,15 @@ class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory): @classmethod def is_available(cls) -> bool: try: - import jax # noqa: F401 + import jax # noqa: F401 return True except ImportError: return False def __call__(self): - from arraycontext import EagerJAXArrayContext from jax.config import config + + from arraycontext import EagerJAXArrayContext config.update("jax_enable_x64", True) return EagerJAXArrayContext() @@ -204,15 +206,17 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): @classmethod def is_available(cls) -> bool: try: - import jax # noqa: F401 - import pytato # noqa: F401 + import jax # noqa: F401 + + import pytato # noqa: F401 return True except ImportError: return False def __call__(self): - from arraycontext import PytatoJAXArrayContext from jax.config import config + + from arraycontext import PytatoJAXArrayContext config.update("jax_enable_x64", True) return PytatoJAXArrayContext() diff --git a/setup.cfg b/setup.cfg index b24271f..60eab3b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,14 @@ docstring-quotes = """ multiline-quotes = """ # enable-flake8-bugbear +[isort] +known_firstparty=pytools,pyopencl,pymbolic,islpy,loopy,pytato +known_local_folder=arraycontext +line_length = 85 +lines_after_imports = 2 +combine_as_imports = True +multi_line_output = 4 + [mypy] # it reads pytato code, and pytato is 3.8+ python_version = 3.8 diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index f373b77..ebc8dc4 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -20,6 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import logging from dataclasses import dataclass from typing import Union @@ -28,26 +29,16 @@ import pytest from pytools.obj_array import make_obj_array -from arraycontext import ( - ArrayContext, - dataclass_array_container, with_container_arithmetic, - serialize_container, deserialize_container, with_array_context, - FirstAxisIsElementsTag, - PyOpenCLArrayContext, - PytatoPyOpenCLArrayContext, - EagerJAXArrayContext, - ArrayContainer, - tag_axes) from arraycontext import ( # noqa: F401 - pytest_generate_tests_for_array_contexts, - ) -from arraycontext.pytest import (_PytestPyOpenCLArrayContextFactoryWithClass, - _PytestPytatoPyOpenCLArrayContextFactory, - _PytestEagerJaxArrayContextFactory, - _PytestPytatoJaxArrayContextFactory) + ArrayContainer, ArrayContext, EagerJAXArrayContext, FirstAxisIsElementsTag, + PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, + deserialize_container, pytest_generate_tests_for_array_contexts, + serialize_container, tag_axes, with_array_context, with_container_arithmetic) +from arraycontext.pytest import ( + _PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, + _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory) -import logging logger = logging.getLogger(__name__) @@ -497,8 +488,9 @@ def test_dof_array_arithmetic_same_as_numpy(actx_factory): return ary.imag import operator + from random import randrange, uniform + from pytools import generate_nonnegative_integer_tuples_below as gnitb - from random import uniform, randrange for op_func, n_args, use_integers in [ (operator.add, 2, False), (operator.sub, 2, False), @@ -775,9 +767,8 @@ def test_container_map_on_device_scalar(actx_factory): arys += (np.pi,) from arraycontext import ( - map_array_container, rec_map_array_container, - map_reduce_array_container, rec_map_reduce_array_container, - ) + map_array_container, map_reduce_array_container, rec_map_array_container, + rec_map_reduce_array_container) for size, ary in zip(expected_sizes, arys[:-1]): result = map_array_container(lambda x: x, ary) @@ -921,6 +912,7 @@ def test_container_arithmetic(actx_factory): assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol from functools import partial + from arraycontext import rec_multimap_array_container for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: rec_multimap_array_container( @@ -973,8 +965,7 @@ def test_container_freeze_thaw(actx_factory): # {{{ check from arraycontext import ( - get_container_context_opt, - get_container_context_recursively_opt) + get_container_context_opt, get_container_context_recursively_opt) assert get_container_context_opt(ary_of_dofs) is None assert get_container_context_opt(mat_of_dofs) is None @@ -1077,7 +1068,7 @@ def test_flatten_array_container(actx_factory, shapes): def _checked_flatten(ary, actx, leaf_class=None): - from arraycontext import flatten, flat_size_and_dtype + from arraycontext import flat_size_and_dtype, flatten result = flatten(ary, actx, leaf_class=leaf_class) if leaf_class is None: @@ -1279,7 +1270,7 @@ def test_actx_compile_kwargs(actx_factory): def test_actx_compile_with_tuple_output_keys(actx_factory): # arraycontext.git<=3c9aee68 would fail due to a bug in output # key stringification logic. - from arraycontext import (to_numpy, from_numpy) + from arraycontext import from_numpy, to_numpy actx = actx_factory() def my_rhs(scale, vel): @@ -1361,6 +1352,7 @@ def test_leaf_array_type_broadcasting(actx_factory): return True else: import pyopencl as cl + # See https://github.com/inducer/pyopencl/issues/498 return cl.version.VERSION > (2021, 2, 5) @@ -1573,6 +1565,7 @@ def test_tagging(actx_factory): def test_compile_anonymous_function(actx_factory): from functools import partial + # See https://github.com/inducer/grudge/issues/287 actx = actx_factory() f = actx.compile(lambda x: 2*x+40) diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index fd3c6ea..eea1144 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -22,13 +22,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from arraycontext import PytatoPyOpenCLArrayContext -from arraycontext import pytest_generate_tests_for_array_contexts -from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory -from pytools.tag import Tag +import logging import pytest -import logging + +from pytools.tag import Tag + +from arraycontext import ( + PytatoPyOpenCLArrayContext, pytest_generate_tests_for_array_contexts) +from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory + + logger = logging.getLogger(__name__) diff --git a/test/test_utils.py b/test/test_utils.py index 5edf66b..94f6b0a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -22,11 +22,12 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import pytest +import logging import numpy as np +import pytest + -import logging logger = logging.getLogger(__name__) @@ -47,8 +48,9 @@ def test_pt_actx_key_stringification_uniqueness(): # {{{ test_dataclass_array_container def test_dataclass_array_container() -> None: - from typing import Optional from dataclasses import dataclass, field + from typing import Optional + from arraycontext import dataclass_array_container # {{{ string fields @@ -111,10 +113,9 @@ def test_dataclass_array_container() -> None: def test_dataclass_container_unions() -> None: from dataclasses import dataclass - from arraycontext import dataclass_array_container - from typing import Union - from arraycontext import Array + + from arraycontext import Array, dataclass_array_container # {{{ union fields -- GitLab