From e68e85523c281dbac32a932c427993104dd30038 Mon Sep 17 00:00:00 2001
From: arghdos <arghdos@gmail.com>
Date: Fri, 18 Aug 2017 14:41:26 -0400
Subject: [PATCH] add support for atomic_init in OpenCL

---
 loopy/target/c/__init__.py |  6 +++++-
 loopy/target/opencl.py     | 12 ++++++++++++
 test/test_loopy.py         | 19 +++++++++++++++++++
 3 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py
index a2ad68250..47130c1f7 100644
--- a/loopy/target/c/__init__.py
+++ b/loopy/target/c/__init__.py
@@ -631,7 +631,11 @@ class CASTBuilder(ASTBuilderBase):
                         needed_dtype=lhs_dtype))
 
         elif isinstance(lhs_atomicity, AtomicInit):
-            raise NotImplementedError("atomic init")
+            codegen_state.seen_atomic_dtypes.add(lhs_dtype)
+            return codegen_state.ast_builder.emit_atomic_init(
+                    codegen_state, lhs_atomicity, lhs_var,
+                    insn.assignee, insn.expression,
+                    lhs_dtype, rhs_type_context)
 
         elif isinstance(lhs_atomicity, AtomicUpdate):
             codegen_state.seen_atomic_dtypes.add(lhs_dtype)
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index a5f7562c4..95299ef52 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -507,6 +507,18 @@ class OpenCLCASTBuilder(CASTBuilder):
 
         return CLConstant(arg_decl)
 
+    # {{{
+
+    def emit_atomic_init(self, codegen_state, lhs_atomicity, lhs_var,
+            lhs_expr, rhs_expr, lhs_dtype, rhs_type_context):
+        # for the CL1 flavor, this is as simple as a regular update with whatever
+        # the RHS value is...
+
+        return self.emit_atomic_update(codegen_state, lhs_atomicity, lhs_var,
+            lhs_expr, rhs_expr, lhs_dtype, rhs_type_context)
+
+    # }}}
+
     # {{{ code generation for atomic update
 
     def emit_atomic_update(self, codegen_state, lhs_atomicity, lhs_var,
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 0aff90fd2..f11230b1f 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1064,6 +1064,25 @@ def test_atomic_load(ctx_factory):
     assert np.allclose(out, np.full_like(out, (-(2 * n - 1) / float(3 * vec_width))))
 
 
+def test_atomic_init():
+    dtype = np.float32
+    vec_width = 4
+
+    knl = lp.make_kernel(
+            "{ [i,j]: 0<=i<100 }",
+            """
+            out[i%4] = 0 {id=init, atomic=init}
+            """,
+            [
+                lp.GlobalArg("out", dtype, shape=lp.auto, for_atomic=True),
+                "..."
+                ],
+            silenced_warnings=["write_race(init)"])
+    knl = lp.split_iname(knl, 'i', vec_width, inner_tag='l.0')
+
+    print(lp.generate_code_v2(knl).device_code())
+
+
 def test_within_inames_and_reduction():
     # See https://github.com/inducer/loopy/issues/24
 
-- 
GitLab