diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index afbe7ce94a231470c6a92ad31058feedd59d5789..ec13738acbdb90f1983639e9b13b9a66fc46ccb5 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -54,15 +54,21 @@ from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLi
 from arraycontext.container.traversal import (rec_map_array_container,
                                               with_array_context)
 from arraycontext.metadata import NameHint
+from pytools import memoize_method
 
 if TYPE_CHECKING:
     import pytato
     import pyopencl as cl
+    import loopy as lp
 
 if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
     import pyopencl as cl  # noqa: F811
 
 
+import logging
+logger = logging.getLogger(__name__)
+
+
 # {{{ tag conversion
 
 def _preprocess_array_tags(tags: ToTagSetConvertible) -> FrozenSet[Tag]:
@@ -203,6 +209,13 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
     def permits_advanced_indexing(self):
         return True
 
+    def get_target(self):
+        """Return the *target* argument for :func:`pytato.generate_loopy`.
+
+        The base implementation returns *None*.
+        """
+        return None
+
     # }}}
 
 # }}}
@@ -210,6 +223,29 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
 
 # {{{ PytatoPyOpenCLArrayContext
 
+from pytato.target.loopy import LoopyPyOpenCLTarget
+
+
+class _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
+    """A :class:`~pytato.target.loopy.LoopyPyOpenCLTarget` whose loopy
+    target is created with an argument size limit.
+    """
+
+    def __init__(self, limit_arg_size_nbytes: int) -> None:
+        """
+        :arg limit_arg_size_nbytes: passed as *limit_arg_size_nbytes* to
+            :class:`loopy.PyOpenCLTarget` in :meth:`get_loopy_target`.
+        """
+        super().__init__()
+        self.limit_arg_size_nbytes = limit_arg_size_nbytes
+
+    @memoize_method
+    def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]:
+        """Return the (memoized) size-limited :class:`loopy.PyOpenCLTarget`."""
+        from loopy import PyOpenCLTarget
+        return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes)
+
+
 class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
     """
     A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
@@ -232,7 +268,11 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
     """
     def __init__(
             self, queue: "cl.CommandQueue", allocator=None, *,
-            compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None
+            use_memory_pool: Optional[bool] = None,
+            compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,
+
+            # do not use: only for testing
+            _force_svm_arg_limit: Optional[int] = None,
             ) -> None:
         """
         :arg compile_trace_callback: A function of three arguments
@@ -242,16 +282,58 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
             representation. This interface should be considered
             unstable.
         """
+        if allocator is not None and use_memory_pool is not None:
+            raise TypeError("may not specify both allocator and use_memory_pool")
+
+        self.using_svm = None
+
+        if allocator is None:
+            # No allocator given: use SVMAllocator if the device has coarse-grain
+            # buffer SVM and ImmediateAllocator otherwise; pool it if requested.
+            from pyopencl.characterize import has_coarse_grain_buffer_svm
+            has_svm = has_coarse_grain_buffer_svm(queue.device)
+            if has_svm:
+                self.using_svm = True
+
+                from pyopencl.tools import SVMAllocator
+                allocator = SVMAllocator(queue.context, queue=queue)
+
+                if use_memory_pool:
+                    from pyopencl.tools import SVMPool
+                    allocator = SVMPool(allocator)
+            else:
+                self.using_svm = False
+
+                from pyopencl.tools import ImmediateAllocator
+                allocator = ImmediateAllocator(queue.context)
+
+                if use_memory_pool:
+                    from pyopencl.tools import MemoryPool
+                    allocator = MemoryPool(allocator)
+        else:
+            # Check whether the passed allocator allocates SVM
+            try:
+                from pyopencl import SVMPointer
+            except ImportError:
+                # this pyopencl does not provide SVMPointer
+                self.using_svm = False
+            else:
+                # allocate 4 bytes as a probe and look at the type of the result
+                self.using_svm = isinstance(allocator(4), SVMPointer)
+
         import pytato as pt
         import pyopencl.array as cla
 
         super().__init__(compile_trace_callback=compile_trace_callback)
         self.queue = queue
+        self.allocator = allocator
         self.array_types = (pt.Array, cla.Array)
 
         # unused, but necessary to keep the context alive
         self.context = self.queue.context
 
+        self._force_svm_arg_limit = _force_svm_arg_limit
+
     @property
     def _frozen_array_types(self) -> Tuple[Type, ...]:
         import pyopencl.array as cla
@@ -321,6 +403,38 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
             self._rec_map_container(_to_numpy, self.freeze(array)),
             actx=None)
 
+    @memoize_method
+    def get_target(self):
+        """Return the pytato target used when generating loopy code.
+
+        An argument-size-limiting target is returned if
+        *_force_svm_arg_limit* was given, or if SVM is in use on a GPU
+        with coarse-grain buffer SVM; the limit is *_force_svm_arg_limit*
+        if given, else the device's ``max_parameter_size``. Otherwise
+        defers to the superclass.
+        """
+        import pyopencl as cl
+        import pyopencl.characterize as cl_char
+
+        dev = self.queue.device
+
+        if (
+                self._force_svm_arg_limit is not None
+                or (
+                    self.using_svm and dev.type & cl.device_type.GPU
+                    and cl_char.has_coarse_grain_buffer_svm(dev))):
+
+            limit = dev.max_parameter_size
+            if self._force_svm_arg_limit is not None:
+                limit = self._force_svm_arg_limit
+
+            logger.info("limiting argument buffer size for %s to %s bytes",
+                        dev, limit)
+
+            return _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)
+        else:
+            return super().get_target()
+
     def freeze(self, array):
         if np.isscalar(array):
             return array
@@ -415,7 +529,8 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
             pt_prg = pt.generate_loopy(transformed_dag,
                                        options=_DEFAULT_LOOPY_OPTIONS,
                                        cl_device=self.queue.device,
-                                       function_name=function_name)
+                                       function_name=function_name,
+                                       target=self.get_target())
             pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
             self._freeze_prg_cache[normalized_expr] = pt_prg
         else:
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 07cb57b9377db62e1a408db3b618c137b2514094..ac4c01d7ce6ed4e842583acb1c543a97417e73d0 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -420,7 +420,9 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
                     options=lp.Options(
                         return_dict=True,
                         no_numpy=True),
-                    function_name=_prg_id_to_kernel_name(prg_id))
+                    function_name=_prg_id_to_kernel_name(prg_id),
+                    target=self.actx.get_target(),
+                    )
                 assert isinstance(pytato_program, BoundPyOpenCLProgram)
 
         self.actx._compile_trace_callback(
diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py
index f4d132ca8223e80198bd774ba9fc822fe824ff08..7dc76c4c9f78ce128044eecf2502e9cdd7a7386c 100644
--- a/test/test_pytato_arraycontext.py
+++ b/test/test_pytato_arraycontext.py
@@ -100,6 +100,28 @@ def test_tags_preserved_after_freeze(actx_factory):
     assert foo.axes[1].tags_of_type(BazTag)
 
 
+def test_arg_size_limit(actx_factory):
+    """A forced argument size limit must reach the final IR's target."""
+    ran_callback = False
+
+    def my_ctc(what, stage, ir):
+        nonlocal ran_callback
+        if stage == "final":
+            assert ir.target.limit_arg_size_nbytes == 42
+            ran_callback = True
+
+    def twice(x):
+        return 2 * x
+
+    actx = _PytatoPyOpenCLArrayContextForTests(
+        actx_factory().queue, compile_trace_callback=my_ctc, _force_svm_arg_limit=42)
+
+    f = actx.compile(twice)
+    f(99)
+
+    assert ran_callback
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1: