From 85e4285d7aab463aa18d6152238103ec5d6eabbc Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 26 Mar 2023 09:09:37 -0500 Subject: [PATCH] pytato: avoid codegen when possible --- arraycontext/impl/pytato/__init__.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8fe3055..3ad3d70 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -486,6 +486,16 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): # }}} + def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: + key_str = "_ary" + _ary_container_key_stringifier(key) + return key_to_frozen_subary[key_str] + + if not key_to_pt_arrays: + # all cl arrays => no need to perform any codegen + return with_array_context( + rec_keyed_map_array_container(_to_frozen, array), + actx=None) + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( key_to_pt_arrays) normalized_expr, bound_arguments = _normalize_pt_expr( @@ -544,10 +554,6 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: - key_str = "_ary" + _ary_container_key_stringifier(key) - return key_to_frozen_subary[key_str] - return with_array_context( rec_keyed_map_array_container(_to_frozen, array), actx=None) @@ -800,6 +806,16 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): # }}} + def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray: + key_str = "_ary" + _ary_container_key_stringifier(key) + return key_to_frozen_subary[key_str] + + if not key_to_pt_arrays: + # all cl arrays => no need to perform any codegen + return with_array_context( + rec_keyed_map_array_container(_to_frozen, array), + actx=None) + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(key_to_pt_arrays) transformed_dag = self.transform_dag(pt_dict_of_named_arrays) pt_prg = pt.generate_jax(transformed_dag, jit=True) @@ -812,10 +828,6 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray: - key_str = "_ary" + _ary_container_key_stringifier(key) - return key_to_frozen_subary[key_str] - return with_array_context( rec_keyed_map_array_container(_to_frozen, array), actx=None) -- GitLab