diff --git a/arraycontext/context.py b/arraycontext/context.py
index de4aa69e3810c357bc5af76050384a42602d02db..411c5312651a94382dc15a52f8a4d4642d3b7a40 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -102,13 +102,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Sequence, Union, Callable, Any, Tuple
+from typing import Sequence, Union, Callable, Any
 from abc import ABC, abstractmethod
 
 import numpy as np
 from pytools import memoize_method
 from pytools.tag import Tag
-from numbers import Number
 
 
 # {{{ ArrayContext
@@ -351,9 +350,7 @@ class ArrayContext(ABC):
             "setup-only" array context "leaks" into the application.
         """
 
-    def compile(self, f: Callable[[Any], Any],
-            inputs_like: Tuple[Union[Number, np.ndarray], ...]) -> Callable[
-                ..., Any]:
+    def compile(self, f: Callable[[Any], Any]) -> Callable[..., Any]:
         """Compiles *f* for repeated use on this array context. *f* is expected
         to be a `pure function <https://en.wikipedia.org/wiki/Pure_function>`__
         performing an array computation.
diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py
index e8f4b44cf40abcc24021aace3fb341cbfbd00e00..612f4a9368ef0e8b7932b75ced1338e20ea34d34 100644
--- a/arraycontext/impl/pytato.py
+++ b/arraycontext/impl/pytato.py
@@ -26,17 +26,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-
 from arraycontext.fake_numpy import \
         BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
 from arraycontext.context import ArrayContext
 from arraycontext.container.traversal import \
         rec_multimap_array_container, rec_map_array_container
 import numpy as np
-from typing import Any, Callable, Tuple, Union, Sequence
+from typing import Any, Callable, Tuple, Union, Sequence, Mapping
 from pytools.tag import Tag
-from numbers import Number
 import loopy as lp
+from dataclasses import dataclass, field
+from pyrsistent import pmap, PMap
 
 
 class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
@@ -167,10 +167,127 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
                                  f"(got {order})")
 
         return rec_map_array_container(_rec_ravel, a)
+
     # }}}
 
 
+class AbstractInputDescriptor:
+    def __eq__(self, other):
+        raise NotImplementedError
+
+    def __hash__(self, other):
+        raise NotImplementedError
+
+
+@dataclass(frozen=True, eq=True)
+class ScalarInputDescriptor(AbstractInputDescriptor):
+    dtype: np.dtype
+
+
+@dataclass(frozen=True, eq=True)
+class ArrayContainerInputDescriptor(AbstractInputDescriptor):
+    id_to_ary_descr: "PMap[Tuple[Union[str, int], ...], Tuple[np.dtype, \
+                                                        Tuple[int, ...]]]"
+
+
+@dataclass
 class PytatoCompiledOperator:
+    actx: ArrayContext
+    f: Callable[[Any], Any]
+    program_cache: Mapping[Tuple[AbstractInputDescriptor],
+                           "PytatoExecutable"] = field(default_factory=lambda: {})
+
+    def __call__(self, *args):
+
+        from arraycontext import (rec_keyed_map_array_container,
+                                  is_array_container)
+        import pytato as pt
+
+        def to_arg_descr(arg):
+            if np.isscalar(arg):
+                return ScalarInputDescriptor(np.dtype(arg))
+            elif is_array_container(arg):
+                id_to_ary_descr = {}
+
+                def id_collector(keys, ary):
+                    id_to_ary_descr[keys] = (np.dtype(ary.dtype),
+                                             ary.shape)
+                    return ary
+
+                rec_keyed_map_array_container(id_collector, arg)
+                return ArrayContainerInputDescriptor(pmap(id_to_ary_descr))
+            else:
+                raise ValueError("Argument to a compiled operator should be"
+                                 " either a scalar or an array container. Got"
+                                 f" '{arg}'.")
+
+        arg_descrs = tuple(to_arg_descr(arg) for arg in args)
+
+        try:
+            exec_f = self.program_cache[arg_descrs]
+        except KeyError:
+            pass
+        else:
+            return exec_f(*args)
+
+        dict_of_named_arrays = {}
+        # output_naming_map: result id to name of the named array in the
+        # generated pytato DAG.
+        output_naming_map = {}
+        # input_naming_map: argument id to placeholder name in the generated
+        # pytato DAG.
+        input_naming_map = {}
+
+        def to_placeholder(arg, pos):
+            if np.isscalar(arg):
+                name = f"_actx_in_{pos}"
+                input_naming_map[(pos, )] = name
+                return pt.make_placeholder((), np.dtype(arg), name)
+            elif is_array_container(arg):
+                def _rec_to_placeholder(keys, ary):
+                    name = f"_actx_in_{pos}_" + "_".join(str(key)
+                                                         for key in keys)
+                    input_naming_map[(pos,) + keys] = name
+                    return pt.make_placeholder(ary.shape, ary.dtype,
+                                               name)
+                return rec_keyed_map_array_container(_rec_to_placeholder,
+                                                     arg)
+            else:
+                raise NotImplementedError(type(arg))
+
+        outputs = self.f(*[to_placeholder(arg, iarg)
+                           for iarg, arg in enumerate(args)])
+
+        if not is_array_container(outputs):
+            # TODO: We could possibly just short-circuit this interface if the
+            # returned type is a scalar. Not sure if it's worth it though.
+            raise ValueError(f"Function {self.f.__name__} to be compiled did not"
+                             " return an array container.")
+
+        def _as_dict_of_named_arrays(keys, ary):
+            name = "_pt_out_" + "_".join(str(key)
+                                         for key in keys)
+            output_naming_map[keys] = name
+            dict_of_named_arrays[name] = ary
+            return ary
+
+        rec_keyed_map_array_container(_as_dict_of_named_arrays,
+                                      outputs)
+
+        pytato_program = pt.generate_loopy(dict_of_named_arrays,
+                                           options={"return_dict": True},
+                                           cl_device=self.actx.queue.device)
+
+        self.program_cache[arg_descrs] = PytatoExecutable(self.actx,
+                                                          pytato_program,
+                                                          input_naming_map,
+                                                          output_naming_map,
+                                                          output_template=outputs)
+
+        return self.program_cache[arg_descrs](*args)
+
+
+class PytatoExecutable:
     def __init__(self, actx, pytato_program, input_id_to_name_in_program,
                  output_id_to_name_in_program, output_template):
         self.actx = actx
@@ -329,67 +446,8 @@ class PytatoArrayContext(ArrayContext):
 
     # }}}
 
-    def compile(self, f: Callable[[Any], Any],
-                inputs_like: Tuple[Union[Number, np.ndarray], ...]
-                ) -> Callable[..., Any]:
-        from arraycontext import (rec_keyed_map_array_container,
-                                  is_array_container)
-        import pytato as pt
-
-        dict_of_named_arrays = {}
-        output_naming_map = {}
-        input_naming_map = {}
-
-        def to_placeholder(input_like, pos):
-            if isinstance(input_like, np.number):
-                name = f"_actx_in_{pos}"
-                input_naming_map[(pos, )] = name
-                return pt.make_placeholder((), input_like.dtype, name)
-            elif is_array_container(input_like):
-                def _rec_to_placeholder(keys, ary):
-                    name = f"_actx_in_{pos}_" + "_".join(str(key)
-                                                       for key in keys)
-                    input_naming_map[(pos,) + keys] = name
-                    return pt.make_placeholder(ary.shape, ary.dtype,
-                                               name)
-                return rec_keyed_map_array_container(_rec_to_placeholder,
-                                                     input_like)
-            else:
-                raise NotImplementedError("Unknown input type "
-                                          f"'{type(input_like)}'.")
-
-        outputs = f(*[to_placeholder(el, iel)
-                      for iel, el in enumerate(inputs_like)])
-
-        if not is_array_container(outputs):
-            # TODO: We could possibly just short-circuit this interface if the
-            # returned type is a scalar. Not sure if it's worth it though.
-            raise ValueError("Function to be compiled did not return an array"
-                             " container.")
-
-        def _as_dict_of_named_arrays(keys, ary):
-            name = "_pt_out_" + "_".join(str(key)
-                                          for key in keys)
-            output_naming_map[keys] = name
-            dict_of_named_arrays[name] = ary
-            return ary
-
-        rec_keyed_map_array_container(_as_dict_of_named_arrays,
-                                      outputs)
-
-        pytato_program = pt.generate_loopy(dict_of_named_arrays,
-                                           options={"return_dict": True},
-                                           cl_device=self.queue.device)
-
-        if False:
-            # transforming leads to compile-time slow downs (turning off for now)
-            pytato_program.program = self.transform_loopy_program(pytato_program
-                                                                  .program)
-
-        return PytatoCompiledOperator(self, pytato_program,
-                                      input_naming_map,
-                                      output_naming_map,
-                                      output_template=outputs)
+    def compile(self, f: Callable[[Any], Any]) -> Callable[..., Any]:
+        return PytatoCompiledOperator(self, f)
 
     def transform_loopy_program(self, prg):
         from loopy.translation_unit import for_each_kernel