Skip to content
Snippets Groups Projects
Commit 85e4285d authored by Kaushik Kulkarni's avatar Kaushik Kulkarni Committed by Andreas Klöckner
Browse files

pytato: avoid codegen when possible

parent 829f8929
No related branches found
No related tags found
No related merge requests found
Pipeline #422159 failed
......@@ -486,6 +486,16 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
# }}}
def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray:
key_str = "_ary" + _ary_container_key_stringifier(key)
return key_to_frozen_subary[key_str]
if not key_to_pt_arrays:
# all cl arrays => no need to perform any codegen
return with_array_context(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
key_to_pt_arrays)
normalized_expr, bound_arguments = _normalize_pt_expr(
......@@ -544,10 +554,6 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
for k, v in out_dict.items()}
}
def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray:
key_str = "_ary" + _ary_container_key_stringifier(key)
return key_to_frozen_subary[key_str]
return with_array_context(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)
......@@ -800,6 +806,16 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
# }}}
def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray:
key_str = "_ary" + _ary_container_key_stringifier(key)
return key_to_frozen_subary[key_str]
if not key_to_pt_arrays:
# all cl arrays => no need to perform any codegen
return with_array_context(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(key_to_pt_arrays)
transformed_dag = self.transform_dag(pt_dict_of_named_arrays)
pt_prg = pt.generate_jax(transformed_dag, jit=True)
......@@ -812,10 +828,6 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
for k, v in out_dict.items()}
}
def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray:
key_str = "_ary" + _ary_container_key_stringifier(key)
return key_to_frozen_subary[key_str]
return with_array_context(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment