diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8fe305571c1a20f02660ce38ece48060eec354bb..3ad3d70a16f3ac56d207cb73b271877702d58b98 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -486,6 +486,16 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): # }}} + def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: + key_str = "_ary" + _ary_container_key_stringifier(key) + return key_to_frozen_subary[key_str] + + if not key_to_pt_arrays: + # all cl arrays => no need to perform any codegen + return with_array_context( + rec_keyed_map_array_container(_to_frozen, array), + actx=None) + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( key_to_pt_arrays) normalized_expr, bound_arguments = _normalize_pt_expr( @@ -544,10 +554,6 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: - key_str = "_ary" + _ary_container_key_stringifier(key) - return key_to_frozen_subary[key_str] - return with_array_context( rec_keyed_map_array_container(_to_frozen, array), actx=None) @@ -800,6 +806,16 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): # }}} + def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray: + key_str = "_ary" + _ary_container_key_stringifier(key) + return key_to_frozen_subary[key_str] + + if not key_to_pt_arrays: + # all cl arrays => no need to perform any codegen + return with_array_context( + rec_keyed_map_array_container(_to_frozen, array), + actx=None) + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(key_to_pt_arrays) transformed_dag = self.transform_dag(pt_dict_of_named_arrays) pt_prg = pt.generate_jax(transformed_dag, jit=True) @@ -812,10 +828,6 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray: - key_str = "_ary" + _ary_container_key_stringifier(key) - return key_to_frozen_subary[key_str] - return with_array_context( rec_keyed_map_array_container(_to_frozen, array), actx=None)