diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index be76e71d508d4daa16cd7fff9d78ff3f2078727f..3805db2910ab5ca65f402ef26fecfcc72a35e7ee 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -187,6 +187,49 @@ class LazilyCompilingFunctionCaller: program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]", "CompiledFunction"] = field(default_factory=lambda: {}) + def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays): + from pytato.target.loopy import BoundPyOpenCLProgram + + import loopy as lp + + with ProcessLogger(logger, "transform_dag"): + pt_dict_of_named_arrays = self.actx.transform_dag( + pt.make_dict_of_named_arrays(dict_of_named_arrays)) + + with ProcessLogger(logger, "generate_loopy"): + pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, + options=lp.Options( + return_dict=True, + no_numpy=True), + cl_device=self.actx.queue.device) + assert isinstance(pytato_program, BoundPyOpenCLProgram) + + with ProcessLogger(logger, "transform_loopy_program"): + + pytato_program = (pytato_program + .with_transformed_program( + lambda x: x.with_kernel( + x.default_entrypoint + .tagged(FromArrayContextCompile())))) + + pytato_program = (pytato_program + .with_transformed_program(self + .actx + .transform_loopy_program)) + + return pytato_program + + def _dag_to_compiled_func(self, dict_of_named_arrays, + input_id_to_name_in_program, output_id_to_name_in_program, + output_template): + pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays) + + return CompiledFunction( + self.actx, pytato_program, + input_id_to_name_in_program=input_id_to_name_in_program, + output_id_to_name_in_program=output_id_to_name_in_program, + output_template=output_template) + def __call__(self, *args: Any, **kwargs: Any) -> Any: """ Returns the result of :attr:`~LazilyCompilingFunctionCaller.f`'s @@ -197,8 +240,6 @@ class LazilyCompilingFunctionCaller: :attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. The intermediary pytato DAG for *args* is memoized in *self*. """ - from pytato.target.loopy import BoundPyOpenCLProgram - arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( args, kwargs) @@ -210,74 +251,70 @@ class LazilyCompilingFunctionCaller: return compiled_f(arg_id_to_arg) dict_of_named_arrays = {} - # output_naming_map: result id to name of the named array in the - # generated pytato DAG. - output_naming_map = {} - # input_naming_map: argument id to placeholder name in the generated - # pytato DAG. - input_naming_map = { + output_id_to_name_in_program = {} + input_id_to_name_in_program = { arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}" for arg_id in arg_id_to_arg} - outputs = self.f(*[_get_f_placeholder_args(arg, iarg, input_naming_map) - for iarg, arg in enumerate(args)], - **{kw: _get_f_placeholder_args(arg, kw, input_naming_map) - for kw, arg in kwargs.items()}) + output_template = self.f( + *[_get_f_placeholder_args(arg, iarg, input_id_to_name_in_program) + for iarg, arg in enumerate(args)], + **{kw: _get_f_placeholder_args(arg, kw, input_id_to_name_in_program) + for kw, arg in kwargs.items()}) - if not is_array_container_type(outputs.__class__): + if not is_array_container_type(output_template.__class__): # TODO: We could possibly just short-circuit this interface if the # returned type is a scalar. Not sure if it's worth it though. raise NotImplementedError( f"Function '{self.f.__name__}' to be compiled " "did not return an array container, but an instance of " - f"'{outputs.__class__}' instead.") + f"'{output_template.__class__}' instead.") def _as_dict_of_named_arrays(keys, ary): name = "_pt_out_" + "_".join(str(key) for key in keys) - output_naming_map[keys] = name + output_id_to_name_in_program[keys] = name dict_of_named_arrays[name] = ary return ary rec_keyed_map_array_container(_as_dict_of_named_arrays, - outputs) + output_template) - import loopy as lp + from pytato import DictOfNamedArrays + compiled_func = self._dag_to_compiled_func( + DictOfNamedArrays(dict_of_named_arrays), + input_id_to_name_in_program=input_id_to_name_in_program, + output_id_to_name_in_program=output_id_to_name_in_program, + output_template=output_template) - with ProcessLogger(logger, "transform_dag"): - pt_dict_of_named_arrays = self.actx.transform_dag( - pt.make_dict_of_named_arrays(dict_of_named_arrays)) + self.program_cache[arg_id_to_descr] = compiled_func + return compiled_func(arg_id_to_arg) - with ProcessLogger(logger, "generate_loopy"): - pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, - options=lp.Options( - return_dict=True, - no_numpy=True), - cl_device=self.actx.queue.device) - assert isinstance(pytato_program, BoundPyOpenCLProgram) - with ProcessLogger(logger, "transform_loopy_program"): +def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): + input_kwargs_for_loopy = {} - pytato_program = (pytato_program - .with_transformed_program( - lambda x: x.with_kernel( - x.default_entrypoint - .tagged(FromArrayContextCompile())))) - - pytato_program = (pytato_program - .with_transformed_program(self - .actx - .transform_loopy_program)) + for arg_id, arg in arg_id_to_arg.items(): + if np.isscalar(arg): + arg = cla.to_device(actx.queue, np.array(arg)) + elif isinstance(arg, pt.array.DataWrapper): + # got a Datwwrapper => simply gets its data + arg = arg.data + elif isinstance(arg, cla.Array): + # got a frozen array => do nothing + pass + elif isinstance(arg, pt.Array): + # got an array expression => evaluate it + arg = actx.freeze(arg).with_queue(actx.queue) + else: + raise NotImplementedError(type(arg)) - self.program_cache[arg_id_to_descr] = CompiledFunction( - self.actx, pytato_program, - input_naming_map, output_naming_map, - output_template=outputs) + input_kwargs_for_loopy[input_id_to_name_in_program[arg_id]] = arg - return self.program_cache[arg_id_to_descr](arg_id_to_arg) + return input_kwargs_for_loopy -@dataclass +@dataclass(frozen=True) class CompiledFunction: """ A callable which captures the :class:`pytato.target.BoundProgram` resulting @@ -319,40 +356,18 @@ class CompiledFunction: :attr:`CompiledFunction.input_id_to_name_in_program` for input id's representation. """ - from arraycontext.container.traversal import rec_keyed_map_array_container - - input_kwargs_to_loopy = {} - - # {{{ preprocess args to get arguments (CL buffers) to be fed to the - # loopy program - - for arg_id, arg in arg_id_to_arg.items(): - if np.isscalar(arg): - arg = cla.to_device(self.actx.queue, np.array(arg)) - elif isinstance(arg, pt.array.DataWrapper): - # got a Datwwrapper => simply gets its data - arg = arg.data - elif isinstance(arg, cla.Array): - # got a frozen array => do nothing - pass - elif isinstance(arg, pt.Array): - # got an array expression => evaluate it - arg = self.actx.freeze(arg).with_queue(self.actx.queue) - else: - raise NotImplementedError(type(arg)) - - input_kwargs_to_loopy[self.input_id_to_name_in_program[arg_id]] = arg + input_kwargs_for_loopy = _args_to_cl_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) evt, out_dict = self.pytato_program(queue=self.actx.queue, allocator=self.actx.allocator, - **input_kwargs_to_loopy) + **input_kwargs_for_loopy) + # FIXME Kernels (for now) allocate tons of memory in temporaries. If we # race too far ahead with enqueuing, there is a distinct risk of # running out of memory. This mitigates that risk a bit, for now. evt.wait() - # }}} - def to_output_template(keys, _): return self.actx.thaw(out_dict[self.output_id_to_name_in_program[keys]])