diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 8e3a162efde37b8509e92f7be04c92eca707ad98..e9e7dc2918c0398bc577d5373e94864f131e205f 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -841,7 +841,8 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]], If *result* is a :class:`dict` or a :class:`pytato.DictOfNamedArrays` and *options* is not supplied, then the Loopy option - :attr:`~loopy.Options.return_dict` will be set to *True*. + :attr:`~loopy.Options.return_dict` will be set to *True*. If it is supplied, + :attr:`~loopy.Options.return_dict` must already be set to *True*. """ result_is_dict = isinstance(result, (dict, DictOfNamedArrays)) @@ -858,8 +859,18 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]], outputs = preproc_result.outputs compute_order = preproc_result.compute_order - if options is None and result_is_dict: - options = lp.Options(return_dict=True) + if options is None: + options = lp.Options(return_dict=result_is_dict) + elif isinstance(options, dict): + from warnings import warn + warn("Passing a dict for options is deprecated and will stop working in " + "2022. Pass an actual loopy.Options object instead.", + DeprecationWarning, stacklevel=2) + options = lp.Options(**options) + + if options.return_dict != result_is_dict: + raise ValueError("options.result_is_dict is expected to match " + "whether the returned value is a dictionary") state = get_initial_codegen_state(target, options) diff --git a/test/test_codegen.py b/test/test_codegen.py index 202958adbd3d86c814817630e049db21b97b1e91..46c4efbe6c08228088524af122bc7aba6df7c82a 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -135,13 +135,6 @@ def test_codegen_with_DictOfNamedArrays(ctx_factory): # noqa result = pt.DictOfNamedArrays(dict(x_out=x, y_out=y)) - # Without return_dict. - prog = pt.generate_loopy(result, cl_device=queue.device, - options=lp.Options(return_dict=False)) - _, (x_out, y_out) = prog(queue, x=x_in, y=y_in) - assert (x_out == x_in).all() - assert (y_out == y_in).all() - # With return_dict. prog = pt.generate_loopy(result, cl_device=queue.device)