From 1ee78297aa2b254ff2660be6dcd4de90123e0d4c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 21 May 2016 19:21:18 +0200
Subject: [PATCH] Support for global temporaries

---
 doc/ref_transform.rst         |   2 +
 loopy/__init__.py             |   5 +-
 loopy/auto_test.py            |  13 +++-
 loopy/check.py                |  48 ++++++++------
 loopy/codegen/__init__.py     |   9 ++-
 loopy/compiled.py             |  25 ++++++-
 loopy/kernel/__init__.py      |  25 +++++--
 loopy/kernel/creation.py      |   4 +-
 loopy/kernel/data.py          |  69 ++++++++++++++------
 loopy/preprocess.py           | 118 +++++++++++++++++++++-------------
 loopy/target/c/__init__.py    |  19 ++++--
 loopy/target/pyopencl.py      |  54 ++++++++++++++--
 loopy/transform/buffer.py     |  33 ++++++++--
 loopy/transform/data.py       |  46 ++++++++++++-
 loopy/transform/precompute.py |  45 +++++++++----
 loopy/version.py              |   2 +-
 test/test_loopy.py            |  30 ++++++++-
 17 files changed, 415 insertions(+), 132 deletions(-)

diff --git a/doc/ref_transform.rst b/doc/ref_transform.rst
index 386fbc18a..fcd470dc2 100644
--- a/doc/ref_transform.rst
+++ b/doc/ref_transform.rst
@@ -96,6 +96,8 @@ Modifying Arguments
 
 .. autofunction:: rename_argument
 
+.. autofunction:: set_temporary_scope
+
 Creating Batches of Operations
 ------------------------------
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index fc3fb208a..19d9ddbc7 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -78,7 +78,8 @@ from loopy.transform.data import (
         add_prefetch, change_arg_to_image, tag_data_axes,
         set_array_dim_names, remove_unused_arguments,
         alias_temporaries, set_argument_order,
-        rename_argument)
+        rename_argument,
+        set_temporary_scope)
 
 from loopy.transform.subst import (extract_subst,
         assignment_to_subst, expand_subst, find_rules_matching,
@@ -166,7 +167,7 @@ __all__ = [
         "add_prefetch", "change_arg_to_image", "tag_data_axes",
         "set_array_dim_names", "remove_unused_arguments",
         "alias_temporaries", "set_argument_order",
-        "rename_argument",
+        "rename_argument", "set_temporary_scope",
 
         "find_instructions", "map_instructions",
         "set_instruction_priority", "add_dependency",
diff --git a/loopy/auto_test.py b/loopy/auto_test.py
index 0adf4416d..bada80328 100644
--- a/loopy/auto_test.py
+++ b/loopy/auto_test.py
@@ -79,7 +79,7 @@ def make_ref_args(kernel, impl_arg_info, queue, parameters):
     import pyopencl as cl
     import pyopencl.array as cl_array
 
-    from loopy.kernel.data import ValueArg, GlobalArg, ImageArg
+    from loopy.kernel.data import ValueArg, GlobalArg, ImageArg, TemporaryVariable
 
     from pymbolic import evaluate
 
@@ -177,6 +177,11 @@ def make_ref_args(kernel, impl_arg_info, queue, parameters):
                         ref_alloc_size=alloc_size,
                         ref_numpy_strides=numpy_strides,
                         needs_checking=is_output))
+
+        elif arg.arg_class is TemporaryVariable:
+            # global temporary, handled by invocation logic
+            pass
+
         else:
             raise LoopyError("arg type not understood")
 
@@ -191,7 +196,7 @@ def make_args(kernel, impl_arg_info, queue, ref_arg_data, parameters):
     import pyopencl as cl
     import pyopencl.array as cl_array
 
-    from loopy.kernel.data import ValueArg, GlobalArg, ImageArg
+    from loopy.kernel.data import ValueArg, GlobalArg, ImageArg, TemporaryVariable
 
     from pymbolic import evaluate
 
@@ -275,6 +280,10 @@ def make_args(kernel, impl_arg_info, queue, ref_arg_data, parameters):
             arg_desc.test_numpy_strides = numpy_strides
             arg_desc.test_alloc_size = alloc_size
 
+        elif arg.arg_class is TemporaryVariable:
+            # global temporary, handled by invocation logic
+            pass
+
         else:
             raise LoopyError("arg type not understood")
 
diff --git a/loopy/check.py b/loopy/check.py
index 0ef3d27cf..910ab24ab 100644
--- a/loopy/check.py
+++ b/loopy/check.py
@@ -127,9 +127,35 @@ def check_for_inactive_iname_access(kernel):
                     % insn.id)
 
 
+def _is_racing_iname_tag(tv, tag):
+    from loopy.kernel.data import (temp_var_scope,
+            LocalIndexTagBase, GroupIndexTag, ParallelTag, auto)
+
+    if tv.scope == temp_var_scope.PRIVATE:
+        return (
+                isinstance(tag, ParallelTag)
+                and not isinstance(tag, (LocalIndexTagBase, GroupIndexTag)))
+
+    elif tv.scope == temp_var_scope.LOCAL:
+        return (
+                isinstance(tag, ParallelTag)
+                and not isinstance(tag, GroupIndexTag))
+
+    elif tv.scope == temp_var_scope.GLOBAL:
+        return isinstance(tag, ParallelTag)
+
+    elif tv.scope == auto:
+        raise LoopyError("scope of temp var '%s' has not yet been"
+                "determined" % tv.name)
+
+    else:
+        raise ValueError("unexpected value of temp_var.scope for "
+                "temporary variable '%s'" % tv.name)
+
+
 def check_for_write_races(kernel):
     from loopy.symbolic import DependencyMapper
-    from loopy.kernel.data import ParallelTag, GroupIndexTag, LocalIndexTagBase
+    from loopy.kernel.data import ParallelTag
     depmap = DependencyMapper(composite_leaves=False)
 
     iname_to_tag = kernel.iname_to_tag.get
@@ -162,26 +188,10 @@ def check_for_write_races(kernel):
 
             elif assignee_name in kernel.temporary_variables:
                 temp_var = kernel.temporary_variables[assignee_name]
-                if temp_var.is_local is True:
-                    raceable_parallel_insn_inames = set(
-                            iname
-                            for iname in kernel.insn_inames(insn)
-                            if isinstance(iname_to_tag(iname), ParallelTag)
-                            and not isinstance(iname_to_tag(iname), GroupIndexTag))
-
-                elif temp_var.is_local is False:
-                    raceable_parallel_insn_inames = set(
+                raceable_parallel_insn_inames = set(
                             iname
                             for iname in kernel.insn_inames(insn)
-                            if isinstance(iname_to_tag(iname), ParallelTag)
-                            and not isinstance(iname_to_tag(iname),
-                                GroupIndexTag)
-                            and not isinstance(iname_to_tag(iname),
-                                LocalIndexTagBase))
-
-                else:
-                    raise LoopyError("temp var '%s' hasn't decided on "
-                            "whether it is local" % temp_var.name)
+                            if _is_racing_iname_tag(temp_var, iname_to_tag(iname)))
 
             else:
                 raise LoopyError("invalid assignee name in instruction '%s'"
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 6eef793c7..5e3d11396 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -407,7 +407,7 @@ def generate_code_v2(kernel):
 
     # {{{ examine arg list
 
-    from loopy.kernel.data import ValueArg
+    from loopy.kernel.data import ValueArg, temp_var_scope
     from loopy.kernel.array import ArrayBase
 
     implemented_data_info = []
@@ -432,6 +432,13 @@ def generate_code_v2(kernel):
         else:
             raise ValueError("argument type not understood: '%s'" % type(arg))
 
+    for tv in six.itervalues(kernel.temporary_variables):
+        if tv.scope == temp_var_scope.GLOBAL:
+            implemented_data_info.extend(
+                    tv.decl_info(
+                        kernel.target,
+                        index_dtype=kernel.index_dtype))
+
     allow_complex = False
     for var in kernel.args + list(six.itervalues(kernel.temporary_variables)):
         if var.dtype.involves_complex():
diff --git a/loopy/compiled.py b/loopy/compiled.py
index 55feff66a..900ed2ba3 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -302,6 +302,7 @@ def generate_integer_arg_finding_from_strides(gen, kernel, implemented_data_info
 def generate_arg_setup(gen, kernel, implemented_data_info, options):
     import loopy as lp
 
+    from loopy.kernel.data import KernelArgument
     from loopy.kernel.array import ArrayBase
     from loopy.symbolic import StringifyMapper
     from pymbolic import var
@@ -318,10 +319,20 @@ def generate_arg_setup(gen, kernel, implemented_data_info, options):
 
     strify = StringifyMapper()
 
+    expect_no_more_arguments = False
+
     for arg_idx, arg in enumerate(implemented_data_info):
         is_written = arg.base_name in kernel.get_written_variables()
         kernel_arg = kernel.impl_arg_to_arg.get(arg.name)
 
+        if not issubclass(arg.arg_class, KernelArgument):
+            expect_no_more_arguments = True
+            continue
+
+        if expect_no_more_arguments:
+            raise LoopyError("Further arguments encountered after arg info "
+                    "describing a global temporary variable")
+
         if not issubclass(arg.arg_class, ArrayBase):
             args.append(arg.name)
             continue
@@ -552,9 +563,14 @@ def generate_invoker(kernel, codegen_result):
             "out_host=None"
             ]
 
+    from loopy.kernel.data import KernelArgument
     gen = PythonFunctionGenerator(
             "invoke_%s_loopy_kernel" % kernel.name,
-            system_args + ["%s=None" % iai.name for iai in implemented_data_info])
+            system_args + [
+                "%s=None" % idi.name
+                for idi in implemented_data_info
+                if issubclass(idi.arg_class, KernelArgument)
+                ])
 
     gen.add_to_preamble("from __future__ import division")
     gen.add_to_preamble("")
@@ -600,7 +616,10 @@ def generate_invoker(kernel, codegen_result):
         gen("if out_host:")
         with Indentation(gen):
             gen("pass")  # if no outputs (?!)
-            for arg_idx, arg in enumerate(implemented_data_info):
+            for arg in implemented_data_info:
+                if not issubclass(arg.arg_class, KernelArgument):
+                    continue
+
                 is_written = arg.base_name in kernel.get_written_variables()
                 if is_written:
                     gen("%s = %s.get(queue=queue)" % (arg.name, arg.name))
@@ -611,10 +630,12 @@ def generate_invoker(kernel, codegen_result):
         gen("return _lpy_evt, {%s}"
                 % ", ".join("\"%s\": %s" % (arg.name, arg.name)
                     for arg in implemented_data_info
+                    if issubclass(arg.arg_class, KernelArgument)
                     if arg.base_name in kernel.get_written_variables()))
     else:
         out_args = [arg
                 for arg in implemented_data_info
+                    if issubclass(arg.arg_class, KernelArgument)
                 if arg.base_name in kernel.get_written_variables()]
         if out_args:
             gen("return _lpy_evt, (%s,)"
diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py
index 9b2c896ac..5ac63b56e 100644
--- a/loopy/kernel/__init__.py
+++ b/loopy/kernel/__init__.py
@@ -843,9 +843,17 @@ class LoopKernel(RecordWithoutPickling):
 
     @memoize_method
     def global_var_names(self):
+        from loopy.kernel.data import temp_var_scope
+
         from loopy.kernel.data import GlobalArg
-        return set(arg.name for arg in self.args
-            if isinstance(arg, GlobalArg))
+        return (
+                set(
+                    arg.name for arg in self.args
+                    if isinstance(arg, GlobalArg))
+                | set(
+                    tv.name
+                    for tv in six.itervalues(self.temporary_variables)
+                    if tv.scope == temp_var_scope.GLOBAL))
 
     # }}}
 
@@ -1033,14 +1041,17 @@ class LoopKernel(RecordWithoutPickling):
 
     @memoize_method
     def local_var_names(self):
+        from loopy.kernel.data import temp_var_scope
         return set(
             tv.name
             for tv in six.itervalues(self.temporary_variables)
-            if tv.is_local)
+            if tv.scope == temp_var_scope.LOCAL)
 
     def local_mem_use(self):
-        return sum(lv.nbytes for lv in six.itervalues(self.temporary_variables)
-                if lv.is_local)
+        from loopy.kernel.data import temp_var_scope
+        return sum(
+                tv.nbytes for tv in six.itervalues(self.temporary_variables)
+                if tv.scope == temp_var_scope.LOCAL)
 
     # }}}
 
@@ -1213,8 +1224,8 @@ class LoopKernel(RecordWithoutPickling):
         return CompiledKernel(ctx, self)
 
     def __call__(self, queue, **kwargs):
-        return self.get_compiled_kernel(queue.context)(
-                queue, **kwargs)
+        cknl = self.get_compiled_kernel(queue.context)
+        return cknl(queue, **kwargs)
 
     # }}}
 
diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index 9fe0f5b79..e25320729 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -793,7 +793,7 @@ def expand_cses(instructions, cse_prefix="cse_expr"):
         new_temp_vars.append(TemporaryVariable(
                 name=new_var_name,
                 dtype=dtype,
-                is_local=lp.auto,
+                scope=lp.auto,
                 shape=()))
 
         from pymbolic.primitives import Variable
@@ -857,7 +857,7 @@ def create_temporaries(knl, default_order):
                 new_temp_vars[assignee_name] = lp.TemporaryVariable(
                         name=assignee_name,
                         dtype=temp_var_type,
-                        is_local=lp.auto,
+                        scope=lp.auto,
                         base_indices=lp.auto,
                         shape=lp.auto,
                         order=default_order)
diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py
index 8de5919df..9c399997d 100644
--- a/loopy/kernel/data.py
+++ b/loopy/kernel/data.py
@@ -275,19 +275,34 @@ class temp_var_scope:
     .. attribute:: GLOBAL
     """
 
+    # These must occur in ascending order of 'globality' so that
+    # max(scope) does the right thing.
+
     PRIVATE = 0
     LOCAL = 1
     GLOBAL = 2
 
+    @classmethod
+    def stringify(cls, val):
+        if val == cls.PRIVATE:
+            return "private"
+        elif val == cls.LOCAL:
+            return "local"
+        elif val == cls.GLOBAL:
+            return "global"
+        else:
+            raise ValueError("unexpected value of temp_var_scope")
+
 
 class TemporaryVariable(ArrayBase):
     __doc__ = ArrayBase.__doc__ + """
     .. attribute:: storage_shape
     .. attribute:: base_indices
-    .. attribute:: is_local
+    .. attribute:: scope
 
-        Whether this is temporary lives in ``local`` memory.
-        May be *True*, *False*, or :class:`loopy.auto` if this is
+        What memory this temporary variable lives in.
+        One of the values in :class:`temp_var_scope`,
+        or :class:`loopy.auto` if this is
         to be automatically determined.
 
     .. attribute:: base_storage
@@ -304,11 +319,11 @@ class TemporaryVariable(ArrayBase):
     allowed_extra_kwargs = [
             "storage_shape",
             "base_indices",
-            "is_local",
+            "scope",
             "base_storage"
             ]
 
-    def __init__(self, name, dtype=None, shape=(), is_local=auto,
+    def __init__(self, name, dtype=None, shape=(), scope=auto,
             dim_tags=None, offset=0, dim_names=None, strides=None, order=None,
             base_indices=None, storage_shape=None,
             base_storage=None):
@@ -318,10 +333,6 @@ class TemporaryVariable(ArrayBase):
         :arg base_indices: :class:`loopy.auto` or a tuple of base indices
         """
 
-        if is_local is None:
-            raise ValueError("is_local is None is no longer supported. "
-                    "Use loopy.auto.")
-
         if base_indices is None:
             base_indices = (0,) * len(shape)
 
@@ -329,18 +340,25 @@ class TemporaryVariable(ArrayBase):
                 dtype=dtype, shape=shape,
                 dim_tags=dim_tags, offset=offset, dim_names=dim_names,
                 order="C",
-                base_indices=base_indices, is_local=is_local,
+                base_indices=base_indices, scope=scope,
                 storage_shape=storage_shape,
                 base_storage=base_storage)
 
     @property
-    def scope(self):
+    def is_local(self):
         """One of :class:`loopy.temp_var_scope`."""
 
-        if self.is_local:
-            return temp_var_scope.LOCAL
+        if self.scope is auto:
+            return auto
+        elif self.scope == temp_var_scope.LOCAL:
+            return True
+        elif self.scope == temp_var_scope.PRIVATE:
+            return False
+        elif self.scope == temp_var_scope.GLOBAL:
+            raise LoopyError("TemporaryVariable.is_local called on "
+                    "global temporary variable '%s'" % self.name)
         else:
-            return temp_var_scope.PRIVATE
+            raise LoopyError("unexpected value of TemporaryVariable.scope")
 
     @property
     def nbytes(self):
@@ -356,18 +374,31 @@ class TemporaryVariable(ArrayBase):
                 target, is_written=True, index_dtype=index_dtype,
                 shape_override=self.storage_shape)
 
-    def get_arg_decl(self, target, name_suffix, shape, dtype, is_written):
-        return None
+    def get_arg_decl(self, ast_builder, name_suffix, shape, dtype, is_written):
+        if self.scope == temp_var_scope.GLOBAL:
+            return ast_builder.get_global_arg_decl(self.name + name_suffix, shape,
+                    dtype, is_written)
+        else:
+            raise LoopyError("unexpected request for argument declaration of "
+                    "non-global temporary")
 
     def __str__(self):
-        return self.stringify(include_typename=False)
+        if self.scope is auto:
+            scope_str = "auto"
+        else:
+            scope_str = temp_var_scope.stringify(self.scope)
+
+        return (
+                self.stringify(include_typename=False)
+                +
+                " scope:%s" % scope_str)
 
     def __eq__(self, other):
         return (
                 super(TemporaryVariable, self).__eq__(other)
                 and self.storage_shape == other.storage_shape
                 and self.base_indices == other.base_indices
-                and self.is_local == other.is_local
+                and self.scope == other.scope
                 and self.base_storage == other.base_storage)
 
     def update_persistent_hash(self, key_hash, key_builder):
@@ -378,7 +409,7 @@ class TemporaryVariable(ArrayBase):
         super(TemporaryVariable, self).update_persistent_hash(key_hash, key_builder)
         key_builder.rec(key_hash, self.storage_shape)
         key_builder.rec(key_hash, self.base_indices)
-        key_builder.rec(key_hash, self.is_local)
+        key_builder.rec(key_hash, self.scope)
 
 # }}}
 
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 2fdadb48e..8b7be4f04 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -293,30 +293,46 @@ def infer_unknown_types(kernel, expect_completion=False):
 # }}}
 
 
-# {{{ decide which temporaries are local
+# {{{ decide temporary scope
 
-def mark_local_temporaries(kernel):
+def _get_compute_inames_tagged(kernel, insn, tag_base):
+    return set(iname
+            for iname in kernel.insn_inames(insn.id)
+            if isinstance(kernel.iname_to_tag.get(iname), tag_base))
+
+
+def _get_assignee_inames_tagged(kernel, insn, tag_base, tv_name):
+    from loopy.symbolic import get_dependencies
+
+    return set(iname
+            for aname, aindices in insn.assignees_and_indices()
+            for iname in get_dependencies(aindices)
+                & kernel.all_inames()
+            if aname == tv_name
+            if isinstance(kernel.iname_to_tag.get(iname), tag_base))
+
+
+def find_temporary_scope(kernel):
     logger.debug("%s: mark local temporaries" % kernel.name)
 
     new_temp_vars = {}
-    from loopy.kernel.data import LocalIndexTagBase
+    from loopy.kernel.data import (LocalIndexTagBase, GroupIndexTag,
+            temp_var_scope)
     import loopy as lp
 
     writers = kernel.writer_map()
 
-    from loopy.symbolic import get_dependencies
-
     for temp_var in six.itervalues(kernel.temporary_variables):
         # Only fill out for variables that do not yet know if they're
         # local. (I.e. those generated by implicit temporary generation.)
 
-        if temp_var.is_local is not lp.auto:
+        if temp_var.scope is not lp.auto:
             new_temp_vars[temp_var.name] = temp_var
             continue
 
         my_writers = writers.get(temp_var.name, [])
 
-        wants_to_be_local_per_insn = []
+        desired_scope_per_insn = []
         for insn_id in my_writers:
             insn = kernel.id_to_insn[insn_id]
 
@@ -327,54 +343,66 @@ def mark_local_temporaries(kernel):
             # - the instruction is run across more inames (locally) parallel
             #   than are reflected in the assignee indices.
 
-            locparallel_compute_inames = set(iname
-                    for iname in kernel.insn_inames(insn_id)
-                    if isinstance(kernel.iname_to_tag.get(iname), LocalIndexTagBase))
+            locparallel_compute_inames = _get_compute_inames_tagged(
+                    kernel, insn, LocalIndexTagBase)
+
+            locparallel_assignee_inames = _get_assignee_inames_tagged(
+                    kernel, insn, LocalIndexTagBase, temp_var.name)
 
-            locparallel_assignee_inames = set(iname
-                    for aname, aindices in insn.assignees_and_indices()
-                    for iname in get_dependencies(aindices)
-                        & kernel.all_inames()
-                    if aname == temp_var.name
-                    if isinstance(kernel.iname_to_tag.get(iname), LocalIndexTagBase))
+            grpparallel_compute_inames = _get_compute_inames_tagged(
+                    kernel, insn, GroupIndexTag)
+
+            grpparallel_assignee_inames = _get_assignee_inames_tagged(
+                    kernel, insn, GroupIndexTag, temp_var.name)
 
             assert locparallel_assignee_inames <= locparallel_compute_inames
+            assert grpparallel_assignee_inames <= grpparallel_compute_inames
+
+            desired_scope = temp_var_scope.PRIVATE
+            for iname_descr, scope_descr, apin, cpin, scope in [
+                    ("local", "local", locparallel_assignee_inames,
+                        locparallel_compute_inames, temp_var_scope.LOCAL),
+                    ("group", "global", grpparallel_assignee_inames,
+                        grpparallel_compute_inames, temp_var_scope.GLOBAL),
+                    ]:
+
+                if (apin != cpin and bool(locparallel_assignee_inames)):
+                    warn(kernel, "write_race_local(%s)" % insn_id,
+                            "instruction '%s' looks invalid: "
+                            "it assigns to indices based on %s IDs, but "
+                            "its temporary '%s' cannot be made %s because "
+                            "a write race across the iname(s) '%s' would emerge. "
+                            "(Do you need to add an extra iname to your prefetch?)"
+                            % (insn_id, iname_descr, temp_var.name, scope_descr,
+                                ", ".join(cpin - apin)),
+                            WriteRaceConditionWarning)
+
+                if (apin == cpin
+
+                        # doesn't want to be in this scope if there aren't any
+                        # parallel inames of that kind:
+                        and bool(cpin)):
+                    desired_scope = max(desired_scope, scope)
+                    break
+
+            desired_scope_per_insn.append(desired_scope)
 
-            if (locparallel_assignee_inames != locparallel_compute_inames
-                    and bool(locparallel_assignee_inames)):
-                warn(kernel, "write_race_local(%s)" % insn_id,
-                        "instruction '%s' looks invalid: "
-                        "it assigns to indices based on local IDs, but "
-                        "its temporary '%s' cannot be made local because "
-                        "a write race across the iname(s) '%s' would emerge. "
-                        "(Do you need to add an extra iname to your prefetch?)"
-                        % (insn_id, temp_var.name, ", ".join(
-                            locparallel_compute_inames
-                            - locparallel_assignee_inames)),
-                        WriteRaceConditionWarning)
-
-            wants_to_be_local_per_insn.append(
-                    locparallel_assignee_inames == locparallel_compute_inames
-
-                    # doesn't want to be local if there aren't any
-                    # parallel inames:
-                    and bool(locparallel_compute_inames))
-
-        if not wants_to_be_local_per_insn:
+        if not desired_scope_per_insn:
             warn(kernel, "temp_to_write(%s)" % temp_var.name,
                     "temporary variable '%s' never written, eliminating"
                     % temp_var.name, LoopyAdvisory)
 
             continue
 
-        is_local = any(wants_to_be_local_per_insn)
+        overall_scope = max(desired_scope_per_insn)
 
         from pytools import all
-        if not all(wtbl == is_local for wtbl in wants_to_be_local_per_insn):
-            raise LoopyError("not all instructions agree on whether "
-                    "temporary '%s' should be in local memory" % temp_var.name)
+        if not all(iscope == overall_scope for iscope in desired_scope_per_insn):
+            raise LoopyError("not all instructions agree on the "
+                    "the desired scope (private/local/global) of  the "
+                    "temporary '%s'" % temp_var.name)
 
-        new_temp_vars[temp_var.name] = temp_var.copy(is_local=is_local)
+        new_temp_vars[temp_var.name] = temp_var.copy(scope=overall_scope)
 
     return kernel.copy(temporary_variables=new_temp_vars)
 
@@ -486,14 +514,14 @@ def realize_reduction(kernel, insn_id_filter=None):
                 for i in range(ncomp)]
         acc_vars = tuple(var(n) for n in acc_var_names)
 
-        from loopy.kernel.data import TemporaryVariable
+        from loopy.kernel.data import TemporaryVariable, temp_var_scope
 
         for name, dtype in zip(acc_var_names, reduction_dtypes):
             new_temporary_variables[name] = TemporaryVariable(
                     name=name,
                     shape=(),
                     dtype=dtype,
-                    is_local=False)
+                    scope=temp_var_scope.PRIVATE)
 
         outer_insn_inames = temp_kernel.insn_inames(insn)
         bad_inames = frozenset(expr.inames) & outer_insn_inames
@@ -808,7 +836,7 @@ def preprocess_kernel(kernel, device=None):
     from loopy.transform.ilp import add_axes_to_temporaries_for_ilp_and_vec
     kernel = add_axes_to_temporaries_for_ilp_and_vec(kernel)
 
-    kernel = mark_local_temporaries(kernel)
+    kernel = find_temporary_scope(kernel)
     kernel = find_boostability(kernel)
     kernel = limit_boostability(kernel)
 
diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py
index 493cef063..6aca830d9 100644
--- a/loopy/target/c/__init__.py
+++ b/loopy/target/c/__init__.py
@@ -210,15 +210,15 @@ class CASTBuilder(ASTBuilderBase):
             return Const(POD(self, idi.dtype, idi.name))
         else:
             name = idi.base_name or idi.name
-            arg = kernel.arg_dict[name]
+            var_descr = kernel.get_var_descriptor(name)
             from loopy.kernel.data import ArrayBase
-            if isinstance(arg, ArrayBase):
-                return arg.get_arg_decl(
+            if isinstance(var_descr, ArrayBase):
+                return var_descr.get_arg_decl(
                         self,
                         idi.name[len(name):], idi.shape, idi.dtype,
                         idi.is_written)
             else:
-                return arg.get_arg_decl(self)
+                return var_descr.get_arg_decl(self)
 
     def get_function_declaration(self, codegen_state, codegen_result,
             schedule_index):
@@ -234,6 +234,8 @@ class CASTBuilder(ASTBuilderBase):
                             for idi in codegen_state.implemented_data_info])
 
     def get_temporary_decls(self, codegen_state):
+        from loopy.kernel.data import temp_var_scope
+
         kernel = codegen_state.kernel
 
         base_storage_decls = []
@@ -254,9 +256,12 @@ class CASTBuilder(ASTBuilderBase):
 
             if not tv.base_storage:
                 for idi in decl_info:
-                    temp_decls.append(
-                            self.wrap_temporary_decl(
-                                self.get_temporary_decl(kernel, tv, idi), tv.scope))
+                    # global temp vars are mapped to arguments
+                    if tv.scope != temp_var_scope.GLOBAL:
+                        temp_decls.append(
+                                self.wrap_temporary_decl(
+                                    self.get_temporary_decl(
+                                        kernel, tv, idi), tv.scope))
 
             else:
                 offset = 0
diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py
index 779abc02e..72147daf8 100644
--- a/loopy/target/pyopencl.py
+++ b/loopy/target/pyopencl.py
@@ -52,9 +52,11 @@ def adjust_local_temp_var_storage(kernel, device):
 
     new_temp_vars = {}
 
+    from loopy.kernel.data import temp_var_scope
+
     lmem_size = cl_char.usable_local_mem_size(device)
     for temp_var in six.itervalues(kernel.temporary_variables):
-        if not temp_var.is_local:
+        if temp_var.scope != temp_var_scope.LOCAL:
             new_temp_vars[temp_var.name] = \
                     temp_var.copy(storage_shape=temp_var.shape)
             continue
@@ -62,7 +64,8 @@ def adjust_local_temp_var_storage(kernel, device):
         other_loctemp_nbytes = [
                 tv.nbytes
                 for tv in six.itervalues(kernel.temporary_variables)
-                if tv.is_local and tv.name != temp_var.name]
+                if tv.scope == temp_var_scope.LOCAL
+                and tv.name != temp_var.name]
 
         storage_shape = temp_var.storage_shape
 
@@ -450,7 +453,8 @@ def generate_value_arg_setup(kernel, devices, implemented_data_info):
                         'must be supplied")'.format(name=idi.name))))
 
         if idi.dtype.is_integral():
-            gen(Comment("cast to Python int to avoid trouble with struct packing or Boost.Python"))
+            gen(Comment("cast to Python int to avoid trouble "
+                "with struct packing or Boost.Python"))
             if sys.version_info < (3,):
                 py_type = "long"
             else:
@@ -567,20 +571,58 @@ class PyOpenCLPythonASTBuilder(PythonASTBuilderBase):
 
     def get_function_definition(self, codegen_state, codegen_result,
             schedule_index, function_decl, function_body):
+        from loopy.kernel.data import TemporaryVariable
         args = (
                 ["_lpy_cl_kernels", "queue"]
-                + [idi.name for idi in codegen_state.implemented_data_info]
-                + ["wait_for=None"])
+                + [idi.name for idi in codegen_state.implemented_data_info
+                    if not issubclass(idi.arg_class, TemporaryVariable)]
+                + ["wait_for=None", "allocator=None"])
+
+        ecm = self.get_expression_to_code_mapper(codegen_state)
+
+        def alloc_nbytes(idi):
+            return idi.dtype.numpy_dtype.itemsize * (
+                    sum(astrd*(alen-1)
+                        for alen, astrd in zip(idi.unvec_shape, idi.unvec_strides))
+                    + 1)
 
-        from genpy import Function, Suite, ImportAs, Return, FromImport, Line
+        from genpy import (Function, Suite, Import, ImportAs, Return, FromImport,
+                If, Assign, Line, Statement as S)
+        from pymbolic.mapper.stringifier import PREC_NONE
         return Function(
                 codegen_result.current_program(codegen_state).name,
                 args,
                 Suite([
                     FromImport("struct", ["pack as _lpy_pack"]),
                     ImportAs("pyopencl", "_lpy_cl"),
+                    Import("pyopencl.tools"),
+                    Line(),
+                    If("allocator is None",
+                        Assign(
+                            "allocator",
+                            "_lpy_cl_tools.DeferredAllocator(queue.context)")),
+                    Line(),
+                    ] + [
+
+                    # allocate global temporaries
+                    Assign(idi.name, "allocator(%s)"
+                        % ecm(alloc_nbytes(idi), PREC_NONE, "i"))
+                    for idi in codegen_result.implemented_data_info
+                    if issubclass(idi.arg_class, TemporaryVariable)
+
+                    ] + [
                     Line(),
                     function_body,
+                    Line(),
+                    ] + [
+
+                    # free global temporaries
+                    S("%s.release()" % idi.name)
+                    for idi in codegen_result.implemented_data_info
+                    if issubclass(idi.arg_class, TemporaryVariable)
+
+                    ] + [
+                    Line(),
                     Return("_lpy_evt"),
                     ]))
 
diff --git a/loopy/transform/buffer.py b/loopy/transform/buffer.py
index 002d5986a..a7d22b2d0 100644
--- a/loopy/transform/buffer.py
+++ b/loopy/transform/buffer.py
@@ -32,6 +32,7 @@ from pymbolic.mapper.substitutor import make_subst_func
 from pytools.persistent_dict import PersistentDict
 from loopy.tools import LoopyKeyBuilder, PymbolicExpressionHashWrapper
 from loopy.version import DATA_MODEL_VERSION
+from loopy.diagnostic import LoopyError
 
 from pymbolic import var
 
@@ -130,7 +131,8 @@ buffer_array_cache = PersistentDict("loopy-buffer-array-cache-"+DATA_MODEL_VERSI
 # Adding an argument? also add something to the cache_key below.
 def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
         store_expression=None, within=None, default_tag="l.auto",
-        temporary_is_local=None, fetch_bounding_box=False):
+        temporary_scope=None, temporary_is_local=None,
+        fetch_bounding_box=False):
     """
     :arg init_expression: Either *None* (indicating the prior value of the buffered
         array should be read) or an expression optionally involving the
@@ -143,6 +145,27 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
         at all.)
     """
 
+    # {{{ unify temporary_scope / temporary_is_local
+
+    from loopy.kernel.data import temp_var_scope
+    if temporary_is_local is not None:
+        from warnings import warn
+        warn("temporary_is_local is deprecated. Use temporary_scope instead",
+                DeprecationWarning, stacklevel=2)
+
+        if temporary_scope is not None:
+            raise LoopyError("may not specify both temporary_is_local and "
+                    "temporary_scope")
+
+        if temporary_is_local:
+            temporary_scope = temp_var_scope.LOCAL
+        else:
+            temporary_scope = temp_var_scope.PRIVATE
+
+    del temporary_is_local
+
+    # }}}
+
     # {{{ process arguments
 
     if isinstance(init_expression, str):
@@ -181,9 +204,9 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
     else:
         var_shape = ()
 
-    if temporary_is_local is None:
+    if temporary_scope is None:
         import loopy as lp
-        temporary_is_local = lp.auto
+        temporary_scope = lp.auto
 
     # }}}
 
@@ -196,7 +219,7 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
     cache_key = (key_kernel, var_name, tuple(buffer_inames),
             PymbolicExpressionHashWrapper(init_expression),
             PymbolicExpressionHashWrapper(store_expression), within,
-            default_tag, temporary_is_local, fetch_bounding_box)
+            default_tag, temporary_scope, fetch_bounding_box)
 
     if CACHING_ENABLED:
         try:
@@ -312,7 +335,7 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
             dtype=var_descr.dtype,
             base_indices=(0,)*len(abm.non1_storage_shape),
             shape=tuple(abm.non1_storage_shape),
-            is_local=temporary_is_local)
+            scope=temporary_scope)
 
     new_temporary_variables[buf_var_name] = temp_var
 
diff --git a/loopy/transform/data.py b/loopy/transform/data.py
index 02499ded2..3db96712e 100644
--- a/loopy/transform/data.py
+++ b/loopy/transform/data.py
@@ -136,7 +136,8 @@ def _process_footprint_subscripts(kernel, rule_name, sweep_inames,
 
 def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
         default_tag="l.auto", rule_name=None,
-        temporary_name=None, temporary_is_local=None,
+        temporary_name=None,
+        temporary_scope=None, temporary_is_local=None,
         footprint_subscripts=None,
         fetch_bounding_box=False):
     """Prefetch all accesses to the variable *var_name*, with all accesses
@@ -245,7 +246,7 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
             default_tag=default_tag, dtype=arg.dtype,
             fetch_bounding_box=fetch_bounding_box,
             temporary_name=temporary_name,
-            temporary_is_local=temporary_is_local)
+            temporary_scope=temporary_scope, temporary_is_local=temporary_is_local)
 
     # {{{ remove inames that were temporarily added by slice sweeps
 
@@ -529,4 +530,45 @@ def rename_argument(kernel, old_name, new_name, existing_ok=False):
 
 # }}}
 
+
+# {{{ set temporary scope
+
+def set_temporary_scope(kernel, temp_var_names, scope):
+    """
+    :arg temp_var_names: a container with membership checking,
+        or a comma-separated string of variables for which the
+        scope is to be set.
+    :arg scope: One of the values from :class:`temp_var_scope`, or one
+        of the strings ``"private"``, ``"local"``, or ``"global"``.
+    """
+
+    if isinstance(temp_var_names, str):
+        temp_var_names = [s.strip() for s in temp_var_names.split(",")]
+
+    from loopy.kernel.data import temp_var_scope
+    if isinstance(scope, str):
+        try:
+            scope = getattr(temp_var_scope, scope.upper())
+        except AttributeError:
+            raise LoopyError("scope '%s' unknown" % scope)
+
+    if not isinstance(scope, int) or scope not in [
+            temp_var_scope.PRIVATE,
+            temp_var_scope.LOCAL,
+            temp_var_scope.GLOBAL]:
+        raise LoopyError("invalid scope '%s'" % scope)
+
+    new_temp_vars = kernel.temporary_variables.copy()
+    for tv_name in temp_var_names:
+        try:
+            tv = new_temp_vars[tv_name]
+        except KeyError:
+            raise LoopyError("temporary '%s' not found" % tv_name)
+
+        new_temp_vars[tv_name] = tv.copy(scope=scope)
+
+    return kernel.copy(temporary_variables=new_temp_vars)
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py
index 6ea0c06e6..fd6f33efc 100644
--- a/loopy/transform/precompute.py
+++ b/loopy/transform/precompute.py
@@ -240,7 +240,8 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
 def precompute(kernel, subst_use, sweep_inames=[], within=None,
         storage_axes=None, temporary_name=None, precompute_inames=None,
         storage_axis_to_tag={}, default_tag="l.auto", dtype=None,
-        fetch_bounding_box=False, temporary_is_local=None,
+        fetch_bounding_box=False,
+        temporary_scope=None, temporary_is_local=None,
         compute_insn_id=None):
     """Precompute the expression described in the substitution rule determined by
     *subst_use* and store it in a temporary array. A precomputation needs two
@@ -316,6 +317,27 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     eliminated.
     """
 
+    # {{{ unify temporary_scope / temporary_is_local
+
+    from loopy.kernel.data import temp_var_scope
+    if temporary_is_local is not None:
+        from warnings import warn
+        warn("temporary_is_local is deprecated. Use temporary_scope instead",
+                DeprecationWarning, stacklevel=2)
+
+        if temporary_scope is not None:
+            raise LoopyError("may not specify both temporary_is_local and "
+                    "temporary_scope")
+
+        if temporary_is_local:
+            temporary_scope = temp_var_scope.LOCAL
+        else:
+            temporary_scope = temp_var_scope.PRIVATE
+
+    del temporary_is_local
+
+    # }}}
+
     # {{{ check, standardize arguments
 
     if isinstance(sweep_inames, str):
@@ -772,8 +794,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
 
     import loopy as lp
 
-    if temporary_is_local is None:
-        temporary_is_local = lp.auto
+    if temporary_scope is None:
+        temporary_scope = lp.auto
 
     new_temp_shape = tuple(abm.non1_storage_shape)
 
@@ -784,7 +806,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
                 dtype=dtype,
                 base_indices=(0,)*len(new_temp_shape),
                 shape=tuple(abm.non1_storage_shape),
-                is_local=temporary_is_local,
+                scope=temporary_scope,
                 dim_names=non1_storage_axis_names)
 
     else:
@@ -822,19 +844,20 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
 
         temp_var = temp_var.copy(shape=new_temp_shape)
 
-        if temporary_is_local == temp_var.is_local:
+        if temporary_scope == temp_var.scope:
             pass
-        elif temporary_is_local is lp.auto:
-            temporary_is_local = temp_var.is_local
-        elif temp_var.is_local is lp.auto:
+        elif temporary_scope is lp.auto:
+            temporary_scope = temp_var.scope
+        elif temp_var.scope is lp.auto:
             pass
         else:
             raise LoopyError("Existing and new temporary '%s' do not "
-                    "have matching values of 'is_local'"
+                    "have matching scopes (existing: %s, new: %s)"
                     % (temporary_name,
-                        temp_var.is_local, temporary_is_local))
+                        temp_var_scope.stringify(temp_var.scope),
+                        temp_var_scope.stringify(temporary_scope)))
 
-        temp_var = temp_var.copy(is_local=temporary_is_local)
+        temp_var = temp_var.copy(scope=temporary_scope)
 
         # }}}
 
diff --git a/loopy/version.py b/loopy/version.py
index ce1cf3089..cfaa4b9a0 100644
--- a/loopy/version.py
+++ b/loopy/version.py
@@ -32,4 +32,4 @@ except ImportError:
 else:
     _islpy_version = islpy.version.VERSION_TEXT
 
-DATA_MODEL_VERSION = "v32-islpy%s" % _islpy_version
+DATA_MODEL_VERSION = "v33-islpy%s" % _islpy_version
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 588cc9a20..57f90f495 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2614,7 +2614,7 @@ def test_kernel_splitting_with_loop(ctx_factory):
     knl = lp.add_and_infer_dtypes(knl,
             {"a": np.float32, "c": np.float32, "out": np.float32, "n": np.int32})
 
-    ref_knl = knl
+    # ref_knl = knl
 
     knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")
 
@@ -2639,6 +2639,34 @@ def test_kernel_splitting_with_loop(ctx_factory):
     #lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5))
 
 
+def test_global_temporary(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(
+            "{ [i]: 0<=i<n}",
+            """
+            <> c[i] = a[i + 1]
+            out[i] = c[i]
+            """)
+
+    knl = lp.add_and_infer_dtypes(knl,
+            {"a": np.float32, "c": np.float32, "out": np.float32, "n": np.int32})
+    knl = lp.set_temporary_scope(knl, "c", "global")
+
+    ref_knl = knl
+
+    knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")
+
+    cgr = lp.generate_code_v2(knl)
+
+    assert len(cgr.device_programs) == 2
+
+    #print(cgr.device_code())
+    #print(cgr.host_code())
+
+    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab