Skip to content
Snippets Groups Projects
Commit 1255f4a0 authored by Xiaoyu Wei's avatar Xiaoyu Wei
Browse files

Merge branch 'backport-cachefixes' into 'main'

Reuse array contexts when interpolating

See merge request !31
parents 697fd16f 9adb0d43
No related branches found
No related tags found
1 merge request!31Reuse array contexts when interpolating
Pipeline #142337 passed with warnings with stage
in 17 minutes and 21 seconds
......@@ -542,7 +542,7 @@ def invert_affine_transform(mat_a, disp_b):
# {{{ from meshmode interpolation
def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup,
def interpolate_from_meshmode(actx, dof_vec, elements_to_sources_lookup,
order="tree"):
"""Interpolate a DoF vector from :mod:`meshmode`.
......@@ -565,6 +565,15 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup,
if not isinstance(dof_vec, cl.array.Array):
raise TypeError("non-array passed to interpolator")
if not isinstance(actx, PyOpenCLArrayContext):
if isinstance(actx, cl.CommandQueue):
from warnings import warn
warn("Command queue passed to the interpolator. "
"Supply an array context to enable proper caching.")
actx = PyOpenCLArrayContext(actx)
else:
raise ValueError
assert len(elements_to_sources_lookup.discr.groups) == 1
assert len(elements_to_sources_lookup.discr.mesh.groups) == 1
degroup = elements_to_sources_lookup.discr.groups[0]
......@@ -590,10 +599,10 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup,
# mapped source points.
sources_in_element_starts = \
elements_to_sources_lookup.sources_in_element_starts.get(queue)
elements_to_sources_lookup.sources_in_element_starts.get(actx.queue)
sources_in_element_lists = \
elements_to_sources_lookup.sources_in_element_lists.get(queue)
tree = elements_to_sources_lookup.tree.get(queue)
elements_to_sources_lookup.sources_in_element_lists.get(actx.queue)
tree = elements_to_sources_lookup.tree.get(actx.queue)
unit_sources_host = make_obj_array(
[np.zeros_like(srccrd) for srccrd in tree.sources])
......@@ -615,7 +624,7 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup,
ivmapped_el_sources[iaxis, :]
unit_sources = make_obj_array(
[cl.array.to_device(queue, usc) for usc in unit_sources_host])
[cl.array.to_device(actx.queue, usc) for usc in unit_sources_host])
# -----------------------------------------------------
# Carry out evaluations in the local (template) frames.
......@@ -631,13 +640,12 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup,
# interrupting the followed computation.
mapped_sources = np.vstack(
[usc.get(queue) for usc in unit_sources])
[usc.get(actx.queue) for usc in unit_sources])
basis_funcs = degroup.basis()
arr_ctx = PyOpenCLArrayContext(queue)
dof_vec_view = unflatten(
arr_ctx, elements_to_sources_lookup.discr, dof_vec)[0]
actx, elements_to_sources_lookup.discr, dof_vec)[0]
dof_vec_view = dof_vec_view.get()
sym_shape = dof_vec.shape[:-1]
......@@ -666,7 +674,7 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup,
source_vec[sym_id + (source_ids_in_el, )] = \
rsplm @ local_dof_vec[sym_id]
source_vec = cl.array.to_device(queue, source_vec)
source_vec = cl.array.to_device(actx.queue, source_vec)
if order == "tree":
pass # no need to do anything
......@@ -682,7 +690,7 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup,
# {{{ to meshmode interpolation
def interpolate_to_meshmode(queue, potential, leaves_to_nodes_lookup,
def interpolate_to_meshmode(actx, potential, leaves_to_nodes_lookup,
order="tree"):
"""
:arg potential: a DoF vector representing a field in :mod:`volumential`,
......@@ -700,9 +708,16 @@ def interpolate_to_meshmode(queue, potential, leaves_to_nodes_lookup,
else:
raise ValueError(f"order must be 'tree' or 'user' (got {order}).")
arr_ctx = PyOpenCLArrayContext(queue)
target_points = flatten(thaw(
arr_ctx, leaves_to_nodes_lookup.discr.nodes()))
if not isinstance(actx, PyOpenCLArrayContext):
if isinstance(actx, cl.CommandQueue):
from warnings import warn
warn("Command queue passed to the interpolator. "
"Supply an array context to enable proper caching.")
actx = PyOpenCLArrayContext(actx)
else:
raise ValueError
target_points = flatten(thaw(actx, leaves_to_nodes_lookup.discr.nodes()))
traversal = leaves_to_nodes_lookup.trav
tree = leaves_to_nodes_lookup.trav.tree
......@@ -721,7 +736,7 @@ def interpolate_to_meshmode(queue, potential, leaves_to_nodes_lookup,
target_points=target_points, traversal=traversal,
wrangler=None, potential=potential,
potential_in_tree_order=potential_in_tree_order,
dim=dim, tree=tree, queue=queue, q_order=q_order,
dim=dim, tree=tree, queue=actx.queue, q_order=q_order,
dtype=potential.dtype, lbl_lookup=None,
balls_near_box_starts=leaves_to_nodes_lookup.nodes_in_leaf_starts,
balls_near_box_lists=leaves_to_nodes_lookup.nodes_in_leaf_lists)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment