diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index a49a81e70237c231c2644c24f845442075e17b33..a7cea85dd0c18454b24d4df7a310c9ad9a7f8c84 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -152,71 +152,63 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): class PytatoCompiledOperator: - def __init__(self, actx, pytato_program, input_spec, output_spec): + def __init__(self, actx, pytato_program, input_id_to_name_in_program, + output_id_to_name_in_program, output_template): self.actx = actx self.pytato_program = pytato_program - self.input_spec = input_spec - self.output_spec = output_spec + self.input_id_to_name_in_program = input_id_to_name_in_program + self.output_id_to_name_in_program = output_id_to_name_in_program + self.output_template = output_template def __call__(self, *args): import pytato as pt import pyopencl.array as cla - from arraycontext.impl import _is_meshmode_dofarray - from pytools.obj_array import flat_obj_array - - updated_kwargs = {} - - def from_obj_array_to_input_dict(array, pos): - input_dict = {} - for i in range(len(self.input_spec[pos])): - for j in range(self.input_spec[pos][i]): - ary = array[i][j] - arg_name = f"_msh_inp_{pos}_{i}_{j}" - if arg_name not in ( - self.pytato_program.program["_pt_kernel"].arg_dict): - continue + from arraycontext import (is_array_container, + rec_keyed_map_array_container) + + input_kwargs_to_loopy = {} + + # {{{ extract loopy arguments execute the program + + for pos, arg in enumerate(args): + if isinstance(arg, np.number): + input_kwargs_to_loopy[self.input_id_to_name_in_program[pos]] = ( + arg) + elif is_array_container(arg): + def _extract_lpy_kwargs(keys, ary): if isinstance(ary, pt.array.DataWrapper): - input_dict[arg_name] = ary.data + processed_ary = ary.data elif isinstance(ary, cla.Array): - input_dict[arg_name] = ary + processed_ary = ary elif isinstance(ary, pt.Array): - input_dict[arg_name] = self.actx.freeze( - ary).with_queue(self.actx.queue) + processed_ary = (self.actx.freeze(ary) + .with_queue(self.actx.queue)) else: - raise TypeError("Expect pt.DataWrapper or CL-array, got " + raise TypeError("Expect pt.Array or CL-array, got " f"{type(ary)}") - return input_dict + input_kwargs_to_loopy[ + self.input_id_to_name_in_program[(pos,) + + keys]] = processed_ary + return ary - def from_return_dict_to_obj_array(return_dict): - from meshmode.dof_array import DOFArray # pylint: disable=import-error - return flat_obj_array([DOFArray.from_list(self.actx, - [self.actx.thaw(return_dict[f"_msh_out_{i}_{j}"]) - for j in range(self.output_spec[i])]) - for i in range(len(self.output_spec))]) - - for iarg, arg in enumerate(args): - if isinstance(arg, np.number): - arg_name = f"_msh_inp_{iarg}" - if arg_name not in ( - self.pytato_program.program["_pt_kernel"].arg_dict): - continue - - updated_kwargs[arg_name] = cla.to_device(self.actx.queue, - np.array(arg)) - elif isinstance(arg, np.ndarray) and all(_is_meshmode_dofarray(el) - for el in arg): - updated_kwargs.update(from_obj_array_to_input_dict(arg, iarg)) + rec_keyed_map_array_container(_extract_lpy_kwargs, arg) else: - raise NotImplementedError("PytatoCompiledOperator cannot handle" - f" '{type(arg)}'s") + raise NotImplementedError(type(arg)) evt, out_dict = self.pytato_program(queue=self.actx.queue, allocator=self.actx.allocator, - **updated_kwargs) + **input_kwargs_to_loopy) + evt.wait() - return from_return_dict_to_obj_array(out_dict) + # }}} + + def to_output_template(keys, _): + return self.actx.thaw(out_dict[self.output_id_to_name_in_program[keys]]) + + return rec_keyed_map_array_container(to_output_template, + self.output_template) class PytatoArrayContext(ArrayContext): @@ -308,70 +300,66 @@ class PytatoArrayContext(ArrayContext): # }}} def compile(self, f: Callable[[Any], Any], - inputs_like: Tuple[Union[Number, np.ndarray], ...]) -> Callable[ - ..., Any]: - from pytools.obj_array import flat_obj_array - from arraycontext.impl import _is_meshmode_dofarray - from meshmode.dof_array import DOFArray # pylint: disable=import-error + inputs_like: Tuple[Union[Number, np.ndarray], ...] + ) -> Callable[..., Any]: + from arraycontext import (rec_keyed_map_array_container, + is_array_container) import pytato as pt - def make_placeholder_like(input_like, pos): + dict_of_named_arrays = {} + output_naming_map = {} + input_naming_map = {} + + def to_placeholder(input_like, pos): if isinstance(input_like, np.number): - return pt.make_placeholder((), input_like.dtype, - f"_msh_inp_{pos}") - elif isinstance(input_like, np.ndarray) and all(_is_meshmode_dofarray(e) - for e in input_like): - return flat_obj_array([DOFArray.from_list(self, - [pt.make_placeholder(grp_ary.shape, - grp_ary.dtype, f"_msh_inp_{pos}_{i}_{j}") - for j, grp_ary in enumerate(dof_ary)]) - for i, dof_ary in enumerate(input_like)]) - - raise NotImplementedError(f"Unknown input type '{type(input_like)}'.") - - def as_dict_of_named_arrays(fields_obj_ary): - dict_of_named_arrays = {} - # output_spec: a list of length #fields; ith-entry denotes #groups in - # ith-field - output_spec = [] - for i, field in enumerate(fields_obj_ary): - output_spec.append(len(field)) - for j, grp in enumerate(field): - dict_of_named_arrays[f"_msh_out_{i}_{j}"] = grp - - return pt.make_dict_of_named_arrays(dict_of_named_arrays), output_spec - - outputs = f(*[make_placeholder_like(el, iel) + name = f"_pt_in_{pos}" + input_naming_map[(pos, )] = name + return pt.make_placeholder((), input_like.dtype, name) + elif is_array_container(input_like): + def _rec_to_placeholder(keys, ary): + name = f"_pt_in_{pos}_" + "_".join(str(key) + for key in keys) + input_naming_map[(pos,) + keys] = name + return pt.make_placeholder(ary.shape, ary.dtype, + name) + return rec_keyed_map_array_container(_rec_to_placeholder, + input_like) + else: + raise NotImplementedError("Unknown input type " + f"'{type(input_like)}'.") + + outputs = f(*[to_placeholder(el, iel) for iel, el in enumerate(inputs_like)]) - if not (isinstance(outputs, np.ndarray) - and all(_is_meshmode_dofarray(e) - for e in outputs)): - raise TypeError("Can only pass in functions that return numpy" - " array of DOFArrays.") + if not is_array_container(outputs): + # TODO: We could possibly just short-circuit this interface if the + # returned type is a scalar. Not sure if it's worth it though. + raise ValueError("Function to be compiled did not return an array" + " container.") + + def _as_dict_of_named_arrays(keys, ary): + name = "_pt_out_" + "_".join(str(key) + for key in keys) + output_naming_map[keys] = name + dict_of_named_arrays[name] = ary + return ary - output_dict_of_named_arrays, output_spec = as_dict_of_named_arrays(outputs) + rec_keyed_map_array_container(_as_dict_of_named_arrays, + outputs) - pytato_program = pt.generate_loopy(output_dict_of_named_arrays, + pytato_program = pt.generate_loopy(dict_of_named_arrays, options={"return_dict": True}, cl_device=self.queue.device) if False: - from time import time - start = time() # transforming leads to compile-time slow downs (turning off for now) - pytato_program.program = self.transform_loopy_program( - pytato_program.program) - end = time() - print(f"Transforming took {end-start} secs") + pytato_program.program = self.transform_loopy_program(pytato_program + .program) return PytatoCompiledOperator(self, pytato_program, - [[len(arg) for arg in input_like] - if isinstance(input_like, np.ndarray) - else [] - - for input_like in inputs_like], - output_spec) + input_naming_map, + output_naming_map, + output_template=outputs) def transform_loopy_program(self, prg): from loopy.translation_unit import for_each_kernel