From 167ae7c6ee7c96118ed7308e3b9e74a8d199088f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 20 Sep 2022 15:10:38 -0500 Subject: [PATCH] Move ArgSizeLimitingPytatoLoopyPyOpenCLTarget to pytato.utils, remove hard pytato dep --- arraycontext/impl/pytato/__init__.py | 19 +++---------------- arraycontext/impl/pytato/utils.py | 25 ++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index ec13738..33c57bc 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -59,7 +59,6 @@ from pytools import memoize_method if TYPE_CHECKING: import pytato import pyopencl as cl - import loopy as lp if getattr(sys, "_BUILDING_SPHINX_DOCS", False): import pyopencl as cl # noqa: F811 @@ -219,20 +218,6 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC): # {{{ PytatoPyOpenCLArrayContext -from pytato.target.loopy import LoopyPyOpenCLTarget - - -class _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget): - def __init__(self, limit_arg_size_nbytes: int) -> None: - super().__init__() - self.limit_arg_size_nbytes = limit_arg_size_nbytes - - @memoize_method - def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]: - from loopy import PyOpenCLTarget - return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes) - - class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): """ A :class:`ArrayContext` that uses :mod:`pytato` data types to represent @@ -408,7 +393,9 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): logger.info(f"limiting argument buffer size for {dev} to {limit} bytes") - return _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit) + from arraycontext.impl.pytato.utils import \ + ArgSizeLimitingPytatoLoopyPyOpenCLTarget + return ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit) else: return super().get_target() diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 2babd55..9d77202 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -23,12 +23,18 @@ THE SOFTWARE. """ -from typing import Any, Dict, Set, Tuple, Mapping +from typing import Any, Dict, Set, Tuple, Mapping, Optional, TYPE_CHECKING +from pytools import memoize_method + from pytato.array import SizeParam, Placeholder, make_placeholder, Axis as PtAxis from pytato.array import Array, DataWrapper, DictOfNamedArrays from pytato.transform import CopyMapper from pytools import UniqueNameGenerator from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis +from pytato.target.loopy import LoopyPyOpenCLTarget + +if TYPE_CHECKING: + import loopy as lp class _DatawrapperToBoundPlaceholderMapper(CopyMapper): @@ -91,3 +97,20 @@ def get_pt_axes_from_cl_axes(axes: Tuple[ClAxis, ...]) -> Tuple[PtAxis, ...]: def get_cl_axes_from_pt_axes(axes: Tuple[PtAxis, ...]) -> Tuple[ClAxis, ...]: return tuple(ClAxis(axis.tags) for axis in axes) + + +# {{{ arg-size-limiting loopy target + +class ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget): + def __init__(self, limit_arg_size_nbytes: int) -> None: + super().__init__() + self.limit_arg_size_nbytes = limit_arg_size_nbytes + + @memoize_method + def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]: + from loopy import PyOpenCLTarget + return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes) + +# }}} + +# vim: foldmethod=marker -- GitLab