From d7ad4c28cf8dec40d080ba127e087f2234351cc1 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Wed, 26 May 2021 17:16:31 -0500 Subject: [PATCH] add pytato array context from meshmode --- arraycontext/impl/pytato.py | 371 ++++++++++++++++++++++++++++++++++++ doc/array_context.rst | 6 + 2 files changed, 377 insertions(+) create mode 100644 arraycontext/impl/pytato.py diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py new file mode 100644 index 0000000..201511c --- /dev/null +++ b/arraycontext/impl/pytato.py @@ -0,0 +1,371 @@ +""" +.. currentmodule:: arraycontext +.. autoclass:: PytatoArrayContext +""" +__copyright__ = """ +Copyright (C) 2020-1 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +from arraycontext.fake_numpy import \ + BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace +from arraycontext.context import ArrayContext +import numpy as np +from typing import Any, Callable, Tuple, Union, Number +import loopy as lp + + +class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + def norm(self, array, ord=None): + raise NotImplementedError + + +class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): + def _get_fake_numpy_linalg_namespace(self): + return _PytatoFakeNumpyLinalgNamespace(self._array_context) + + @property + def ns(self): + return self._array_context.ns + + def exp(self, x): + import pytato as pt + from meshmode.dof_array import obj_or_dof_array_vectorize + return obj_or_dof_array_vectorize(pt.exp, x) + + def reshape(self, a, newshape): + import pytato as pt + + from meshmode.dof_array import obj_or_dof_array_vectorize_n_args + return obj_or_dof_array_vectorize_n_args(pt.reshape, a, newshape) + + def transpose(self, a, axes=None): + import pytato as pt + + from meshmode.dof_array import obj_or_dof_array_vectorize_n_args + return obj_or_dof_array_vectorize_n_args(pt.transpose, a, axes) + + def concatenate(self, arrays, axis=0): + import pytato as pt + from meshmode.dof_array import obj_or_dof_array_vectorize_n_args + return obj_or_dof_array_vectorize_n_args(pt.concatenate, arrays, axis) + + def maximum(self, x, y): + import pytato as pt + from meshmode.dof_array import obj_or_dof_array_vectorize_n_args + return obj_or_dof_array_vectorize_n_args(pt.maximum, x, y) + + def minimum(self, x, y): + import pytato as pt + from meshmode.dof_array import obj_or_dof_array_vectorize_n_args + return obj_or_dof_array_vectorize_n_args(pt.minimum, x, y) + + def sum(self, a, dtype=None): + import pytato as pt + if dtype not in [a.dtype, None]: + raise NotImplementedError + return pt.sum(a) + + def min(self, a): + import pytato as pt + return pt.amin(a) + + def max(self, a): + import pytato as pt + return pt.amax(a) + + +class PytatoCompiledOperator: + def __init__(self, actx, pytato_program, input_spec, output_spec): + self.actx = actx + self.pytato_program = pytato_program + self.input_spec = input_spec + self.output_spec = output_spec + + def __call__(self, *args): + import pytato as pt + import pyopencl.array as cla + from meshmode.dof_array import DOFArray + from pytools.obj_array import flat_obj_array + + updated_kwargs = {} + + def from_obj_array_to_input_dict(array, pos): + input_dict = {} + for i in range(len(self.input_spec[pos])): + for j in range(self.input_spec[pos][i]): + ary = array[i][j] + arg_name = f"_msh_inp_{pos}_{i}_{j}" + if arg_name not in ( + self.pytato_program.program["_pt_kernel"].arg_dict): + continue + if isinstance(ary, pt.array.DataWrapper): + input_dict[arg_name] = ary.data + elif isinstance(ary, cla.Array): + input_dict[arg_name] = ary + elif isinstance(ary, pt.Array): + input_dict[arg_name] = self.actx.freeze( + ary).with_queue(self.actx.queue) + else: + raise TypeError("Expect pt.DataWrapper or CL-array, got " + f"{type(ary)}") + + return input_dict + + def from_return_dict_to_obj_array(return_dict): + return flat_obj_array([DOFArray.from_list(self.actx, + [self.actx.thaw(return_dict[f"_msh_out_{i}_{j}"]) + for j in range(self.output_spec[i])]) + for i in range(len(self.output_spec))]) + + for iarg, arg in enumerate(args): + if isinstance(arg, np.number): + arg_name = f"_msh_inp_{iarg}" + if arg_name not in ( + self.pytato_program.program["_pt_kernel"].arg_dict): + continue + + updated_kwargs[arg_name] = cla.to_device(self.actx.queue, + np.array(arg)) + elif isinstance(arg, np.ndarray) and all(isinstance(el, DOFArray) + for el in arg): + updated_kwargs.update(from_obj_array_to_input_dict(arg, iarg)) + else: + raise NotImplementedError("PytatoCompiledOperator cannot handle" + f" '{type(arg)}'s") + + evt, out_dict = self.pytato_program(queue=self.actx.queue, + allocator=self.actx.allocator, + **updated_kwargs) + evt.wait() + + return from_return_dict_to_obj_array(out_dict) + + +class PytatoArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :mod:`pytato` data types to represent + the DOF arrays targeting OpenCL for offloading operations. + + .. attribute:: context + + A :class:`pyopencl.Context`. + + .. attribute:: queue + + A :class:`pyopencl.CommandQueue`. + """ + import pytato as pt + _array_type_ = pt.Array + + def __init__(self, queue, allocator=None): + super().__init__() + self.queue = queue + self.allocator = allocator + self.np = self._get_fake_numpy_namespace() + + def _get_fake_numpy_namespace(self): + return _PytatoFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def empty(self, shape, dtype): + raise ValueError("PytatoArrayContext does not support empty") + + def symbolic_array_var(self, shape, dtype, name=None): + import pytato as pt + return pt.make_placeholder(shape=shape, dtype=dtype, name=name) + + def zeros(self, shape, dtype): + import pytato as pt + return pt.zeros(shape, dtype) + + def from_numpy(self, np_array: np.ndarray): + import pytato as pt + import pyopencl.array as cla + cl_array = cla.to_device(self.queue, np_array) + return pt.make_data_wrapper(cl_array) + + def to_numpy(self, array): + cl_array = self.freeze(array) + return cl_array.get(queue=self.queue) + + def call_loopy(self, program, **kwargs): + from pytato.loopy import call_loopy + import pyopencl.array as cla + entrypoint, = set(program.callables_table) + + # thaw frozen arrays + kwargs = {kw: (self.thaw(arg) if isinstance(arg, cla.Array) else arg) + for kw, arg in kwargs.items()} + + return call_loopy(program, kwargs, entrypoint) + + def freeze(self, array): + import pytato as pt + import pyopencl.array as cla + + if isinstance(array, pt.Placeholder): + raise ValueError("freezing placeholder would return garbage valued" + " arrays") + if isinstance(array, cla.Array): + return array.with_queue(None) + if not isinstance(array, pt.Array): + raise TypeError("PytatoArrayContext.freeze invoked with non-pt arrays") + + prg = pt.generate_loopy(array, cl_device=self.queue.device) + evt, (cl_array,) = prg(self.queue) + evt.wait() + + return cl_array.with_queue(None) + + def thaw(self, array): + import pytato as pt + import pyopencl.array as cla + + if not isinstance(array, cla.Array): + raise TypeError("PytatoArrayContext.thaw expects CL arrays, got " + f"{type(array)}") + + return pt.make_data_wrapper(array.with_queue(self.queue)) + + # }}} + + def compile(self, f: Callable[[Any], Any], + inputs_like: Tuple[Union[Number, np.array], ...]) -> Callable[..., Any]: + from pytools.obj_array import flat_obj_array + from meshmode.dof_array import DOFArray + import pytato as pt + + def make_placeholder_like(input_like, pos): + if isinstance(input_like, np.number): + return pt.make_placeholder(input_like.dtype, + f"_msh_inp_{pos}") + elif isinstance(input_like, np.ndarray) and all(isinstance(e, DOFArray) + for e in input_like): + return flat_obj_array([DOFArray.from_list(self, + [pt.make_placeholder(grp_ary.shape, + grp_ary.dtype, f"_msh_inp_{pos}_{i}_{j}") + for j, grp_ary in enumerate(dof_ary)]) + for i, dof_ary in enumerate(input_like)]) + + raise NotImplementedError(f"Unknown input type '{type(input_like)}'.") + + def as_dict_of_named_arrays(fields_obj_ary): + dict_of_named_arrays = {} + # output_spec: a list of length #fields; ith-entry denotes #groups in + # ith-field + output_spec = [] + for i, field in enumerate(fields_obj_ary): + output_spec.append(len(field)) + for j, grp in enumerate(field): + dict_of_named_arrays[f"_msh_out_{i}_{j}"] = grp + + return pt.make_dict_of_named_arrays(dict_of_named_arrays), output_spec + + outputs = f(*[make_placeholder_like(el, iel) + for iel, el in enumerate(inputs_like)]) + + if not (isinstance(outputs, np.ndarray) + and all(isinstance(e, DOFArray) + for e in outputs)): + raise TypeError("Can only pass in functions that return numpy" + " array of DOFArrays.") + + output_dict_of_named_arrays, output_spec = as_dict_of_named_arrays(outputs) + + pytato_program = pt.generate_loopy(output_dict_of_named_arrays, + options={"return_dict": True}, + cl_device=self.queue.device) + + if False: + from time import time + start = time() + # transforming leads to compile-time slow downs (turning off for now) + pytato_program.program = self.transform_loopy_program( + pytato_program.program) + end = time() + print(f"Transforming took {end-start} secs") + + return PytatoCompiledOperator(self, pytato_program, + [[len(arg) for arg in input_like] + if isinstance(input_like, np.ndarray) + else [] + + for input_like in inputs_like], + output_spec) + + def transform_loopy_program(self, prg): + from loopy.program import iterate_over_kernels_if_given_program + + nwg = 48 + nwi = (16, 2) + + @iterate_over_kernels_if_given_program + def gridify(knl): + # {{{ Pattern matching inames + + for insn in knl.instructions: + if isinstance(insn, lp.CallInstruction): + # must be a callable kernel, don't touch. + pass + elif isinstance(insn, lp.Assignment): + bigger_loop = None + smaller_loop = None + for iname in insn.within_inames: + if iname.startswith("iel"): + assert bigger_loop is None + bigger_loop = iname + if iname.startswith("idof"): + assert smaller_loop is None + smaller_loop = iname + + if bigger_loop or smaller_loop: + assert bigger_loop is not None and smaller_loop is not None + else: + sorted_inames = sorted(tuple(insn.within_inames), + key=knl.get_constant_iname_length) + smaller_loop = sorted_inames[0] + bigger_loop = sorted_inames[1] + + knl = lp.chunk_iname(knl, bigger_loop, nwg, + outer_tag="g.0") + knl = lp.split_iname(knl, f"{bigger_loop}_inner", + nwi[0], inner_tag="l.1") + knl = lp.split_iname(knl, smaller_loop, + nwi[1], inner_tag="l.0") + elif isinstance(insn, lp.BarrierInstruction): + pass + else: + raise NotImplementedError + + # }}} + + return knl + + prg = lp.set_options(prg, "insert_additional_gbarriers") + + return gridify(prg) + + +# }}} diff --git a/doc/array_context.rst b/doc/array_context.rst index 4b3b004..aec9fed 100644 --- a/doc/array_context.rst +++ b/doc/array_context.rst @@ -12,3 +12,9 @@ Array context based on :mod:`pyopencl.array` -------------------------------------------- .. automodule:: arraycontext.impl.pyopencl + + +Array context based on :mod:`pytato` +-------------------------------------------- + +.. automodule:: arraycontext.impl.pytato -- GitLab