diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 209820858612b5aa639845f9383185422fdd203c..2eb6ccad3072a0738e5fba6fa6524cf652a7623d 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -278,16 +278,20 @@ class CompiledFunction: input_kwargs_to_loopy = {} - # {{{ extract loopy arguments execute the program + # {{{ preprocess args to get arguments (CL buffers) to be fed to the + # loopy program for arg_id, arg in arg_id_to_arg.items(): if np.isscalar(arg): arg = cla.to_device(self.actx.queue, np.array(arg)) elif isinstance(arg, pt.array.DataWrapper): + # got a Datwwrapper => simply gets its data arg = arg.data elif isinstance(arg, cla.Array): + # got a frozen array => do nothing pass elif isinstance(arg, pt.Array): + # got an array expression => evaluate it arg = self.actx.freeze(arg).with_queue(self.actx.queue) else: raise NotImplementedError(type(arg))