diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index a2ad682505bbdb7ed5977a28e201ebc6655c7784..47130c1f764278979ea005cfd6ec101819a2e13c 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -631,7 +631,11 @@ class CASTBuilder(ASTBuilderBase): needed_dtype=lhs_dtype)) elif isinstance(lhs_atomicity, AtomicInit): - raise NotImplementedError("atomic init") + codegen_state.seen_atomic_dtypes.add(lhs_dtype) + return codegen_state.ast_builder.emit_atomic_init( + codegen_state, lhs_atomicity, lhs_var, + insn.assignee, insn.expression, + lhs_dtype, rhs_type_context) elif isinstance(lhs_atomicity, AtomicUpdate): codegen_state.seen_atomic_dtypes.add(lhs_dtype) diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index a5f7562c41c3ec8eca673904550e078d2a992241..95299ef525b602713a43fd71c95e1918d8270979 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -507,6 +507,18 @@ class OpenCLCASTBuilder(CASTBuilder): return CLConstant(arg_decl) + # {{{ + + def emit_atomic_init(self, codegen_state, lhs_atomicity, lhs_var, + lhs_expr, rhs_expr, lhs_dtype, rhs_type_context): + # for the CL1 flavor, this is as simple as a regular update with whatever + # the RHS value is... + + return self.emit_atomic_update(codegen_state, lhs_atomicity, lhs_var, + lhs_expr, rhs_expr, lhs_dtype, rhs_type_context) + + # }}} + # {{{ code generation for atomic update def emit_atomic_update(self, codegen_state, lhs_atomicity, lhs_var, diff --git a/test/test_loopy.py b/test/test_loopy.py index 0aff90fd275c59f89adfee6ea0b06feb6d982482..f11230b1f0456a5f534b392915c36f158f4fb321 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1064,6 +1064,25 @@ def test_atomic_load(ctx_factory): assert np.allclose(out, np.full_like(out, (-(2 * n - 1) / float(3 * vec_width)))) +def test_atomic_init(): + dtype = np.float32 + vec_width = 4 + + knl = lp.make_kernel( + "{ [i,j]: 0<=i<100 }", + """ + out[i%4] = 0 {id=init, atomic=init} + """, + [ + lp.GlobalArg("out", dtype, shape=lp.auto, for_atomic=True), + "..." + ], + silenced_warnings=["write_race(init)"]) + knl = lp.split_iname(knl, 'i', vec_width, inner_tag='l.0') + + print(lp.generate_code_v2(knl).device_code()) + + def test_within_inames_and_reduction(): # See https://github.com/inducer/loopy/issues/24