From f618fa403d23e9bea45a5631176af2ed28e73c37 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Tue, 22 Jun 2021 05:51:19 -0500 Subject: [PATCH] move function out of another function - closures are difficult to read --- arraycontext/impl/pytato/compile.py | 50 ++++++++++++++++------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index d83ece6..68a76ad 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -29,6 +29,9 @@ THE SOFTWARE. from arraycontext.container import ArrayContainer from arraycontext import PytatoPyOpenCLArrayContext +from arraycontext.container.traversal import (rec_keyed_map_array_container, + is_array_container) + import numpy as np from typing import Any, Callable, Tuple, Dict from dataclasses import dataclass, field @@ -98,6 +101,30 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: return "_".join(_rec_str(key) for key in keys) +def _to_arg_descr(arg: Any) -> AbstractInputDescriptor: + """ + Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. + Returns a :class:`AbstractInputDescriptor` for a + attr:`LazilyCompilingFunctionCaller.f`'s input argument. + """ + if np.isscalar(arg): + return ScalarInputDescriptor(np.dtype(arg)) + elif is_array_container(arg): + id_to_ary_descr = {} + + def id_collector(keys, ary): + id_to_ary_descr[keys] = LeafArrayDescriptor(np.dtype(ary.dtype), + ary.shape) + return ary + + rec_keyed_map_array_container(id_collector, arg) + return ArrayContainerInputDescriptor(pmap(id_to_ary_descr)) + else: + raise ValueError("Argument to a compiled operator should be" + " either a scalar or an array container. Got" + f" '{arg}'.") + + @dataclass class LazilyCompilingFunctionCaller: """ @@ -125,28 +152,7 @@ class LazilyCompilingFunctionCaller: with *args* in a lazy-sense. """ - from arraycontext.container.traversal import (rec_keyed_map_array_container, - is_array_container) - - def to_arg_descr(arg: Any) -> AbstractInputDescriptor: - if np.isscalar(arg): - return ScalarInputDescriptor(np.dtype(arg)) - elif is_array_container(arg): - id_to_ary_descr = {} - - def id_collector(keys, ary): - id_to_ary_descr[keys] = LeafArrayDescriptor(np.dtype(ary.dtype), - ary.shape) - return ary - - rec_keyed_map_array_container(id_collector, arg) - return ArrayContainerInputDescriptor(pmap(id_to_ary_descr)) - else: - raise ValueError("Argument to a compiled operator should be" - " either a scalar or an array container. Got" - f" '{arg}'.") - - arg_descrs = tuple(to_arg_descr(arg) for arg in args) + arg_descrs = tuple(_to_arg_descr(arg) for arg in args) try: compiled_f = self.program_cache[arg_descrs] -- GitLab