diff --git a/volumential/interpolation.py b/volumential/interpolation.py index 85191f93c1dd0fe7c0e7ed7988b28dd790eae994..4559d5f8da51567700480fadb97bf40c25877149 100644 --- a/volumential/interpolation.py +++ b/volumential/interpolation.py @@ -542,7 +542,7 @@ def invert_affine_transform(mat_a, disp_b): # {{{ from meshmode interpolation -def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup, +def interpolate_from_meshmode(actx, dof_vec, elements_to_sources_lookup, order="tree"): """Interpolate a DoF vector from :mod:`meshmode`. @@ -565,6 +565,15 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup, if not isinstance(dof_vec, cl.array.Array): raise TypeError("non-array passed to interpolator") + if not isinstance(actx, PyOpenCLArrayContext): + if isinstance(actx, cl.CommandQueue): + from warnings import warn + warn("Command queue passed to the interpolator. " + "Supply an array context to enable proper caching.") + actx = PyOpenCLArrayContext(actx) + else: + raise ValueError + assert len(elements_to_sources_lookup.discr.groups) == 1 assert len(elements_to_sources_lookup.discr.mesh.groups) == 1 degroup = elements_to_sources_lookup.discr.groups[0] @@ -590,10 +599,10 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup, # mapped source points. sources_in_element_starts = \ - elements_to_sources_lookup.sources_in_element_starts.get(queue) + elements_to_sources_lookup.sources_in_element_starts.get(actx.queue) sources_in_element_lists = \ - elements_to_sources_lookup.sources_in_element_lists.get(queue) - tree = elements_to_sources_lookup.tree.get(queue) + elements_to_sources_lookup.sources_in_element_lists.get(actx.queue) + tree = elements_to_sources_lookup.tree.get(actx.queue) unit_sources_host = make_obj_array( [np.zeros_like(srccrd) for srccrd in tree.sources]) @@ -615,7 +624,7 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup, ivmapped_el_sources[iaxis, :] unit_sources = make_obj_array( - [cl.array.to_device(queue, usc) for usc in unit_sources_host]) + [cl.array.to_device(actx.queue, usc) for usc in unit_sources_host]) # ----------------------------------------------------- # Carry out evaluations in the local (template) frames. @@ -631,13 +640,12 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup, # interrupting the followed computation. mapped_sources = np.vstack( - [usc.get(queue) for usc in unit_sources]) + [usc.get(actx.queue) for usc in unit_sources]) basis_funcs = degroup.basis() - arr_ctx = PyOpenCLArrayContext(queue) dof_vec_view = unflatten( - arr_ctx, elements_to_sources_lookup.discr, dof_vec)[0] + actx, elements_to_sources_lookup.discr, dof_vec)[0] dof_vec_view = dof_vec_view.get() sym_shape = dof_vec.shape[:-1] @@ -666,7 +674,7 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup, source_vec[sym_id + (source_ids_in_el, )] = \ rsplm @ local_dof_vec[sym_id] - source_vec = cl.array.to_device(queue, source_vec) + source_vec = cl.array.to_device(actx.queue, source_vec) if order == "tree": pass # no need to do anything @@ -682,7 +690,7 @@ def interpolate_from_meshmode(queue, dof_vec, elements_to_sources_lookup, # {{{ to meshmode interpolation -def interpolate_to_meshmode(queue, potential, leaves_to_nodes_lookup, +def interpolate_to_meshmode(actx, potential, leaves_to_nodes_lookup, order="tree"): """ :arg potential: a DoF vector representing a field in :mod:`volumential`, @@ -700,9 +708,16 @@ def interpolate_to_meshmode(queue, potential, leaves_to_nodes_lookup, else: raise ValueError(f"order must be 'tree' or 'user' (got {order}).") - arr_ctx = PyOpenCLArrayContext(queue) - target_points = flatten(thaw( - arr_ctx, leaves_to_nodes_lookup.discr.nodes())) + if not isinstance(actx, PyOpenCLArrayContext): + if isinstance(actx, cl.CommandQueue): + from warnings import warn + warn("Command queue passed to the interpolator. " + "Supply an array context to enable proper caching.") + actx = PyOpenCLArrayContext(actx) + else: + raise ValueError + + target_points = flatten(thaw(actx, leaves_to_nodes_lookup.discr.nodes())) traversal = leaves_to_nodes_lookup.trav tree = leaves_to_nodes_lookup.trav.tree @@ -721,7 +736,7 @@ def interpolate_to_meshmode(queue, potential, leaves_to_nodes_lookup, target_points=target_points, traversal=traversal, wrangler=None, potential=potential, potential_in_tree_order=potential_in_tree_order, - dim=dim, tree=tree, queue=queue, q_order=q_order, + dim=dim, tree=tree, queue=actx.queue, q_order=q_order, dtype=potential.dtype, lbl_lookup=None, balls_near_box_starts=leaves_to_nodes_lookup.nodes_in_leaf_starts, balls_near_box_lists=leaves_to_nodes_lookup.nodes_in_leaf_lists)