From 8ebcc0663749c199b3d9b27f10abaf920d0d55bc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 8 Dec 2015 00:04:49 -0600
Subject: [PATCH] Fix CL image arguments

---
 loopy/kernel/data.py     | 2 +-
 loopy/target/__init__.py | 2 +-
 loopy/target/opencl.py   | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py
index c95ca0e91..9266db0e3 100644
--- a/loopy/kernel/data.py
+++ b/loopy/kernel/data.py
@@ -244,7 +244,7 @@ class ImageArg(ArrayBase, KernelArgument):
 
     def get_arg_decl(self, target, name_suffix, shape, dtype, is_written):
         return target.get_image_arg_decl(self.name + name_suffix, shape,
-                dtype, is_written)
+                self.num_target_axes(), dtype, is_written)
 
 
 class ValueArg(KernelArgument):
diff --git a/loopy/target/__init__.py b/loopy/target/__init__.py
index 17534bf6d..ba83b07ee 100644
--- a/loopy/target/__init__.py
+++ b/loopy/target/__init__.py
@@ -119,7 +119,7 @@ class TargetBase(object):
     def get_global_arg_decl(self, name, shape, dtype, is_written):
         raise NotImplementedError()
 
-    def get_image_arg_decl(self, name, shape, dtype, is_written):
+    def get_image_arg_decl(self, name, shape, num_target_axes, dtype, is_written):
         raise NotImplementedError()
 
     # }}}
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index 4a9b0f31d..9b5c4544b 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -309,14 +309,14 @@ class OpenCLTarget(CTarget):
         return CLGlobal(super(OpenCLTarget, self).get_global_arg_decl(
             name, shape, dtype, is_written))
 
-    def get_image_arg_decl(self, name, shape, dtype, is_written):
+    def get_image_arg_decl(self, name, shape, num_target_axes, dtype, is_written):
         if is_written:
             mode = "w"
         else:
             mode = "r"
 
         from cgen.opencl import CLImage
-        return CLImage(self.num_target_axes(), mode, name)
+        return CLImage(num_target_axes, mode, name)
 
     def get_constant_arg_decl(self, name, shape, dtype, is_written):
         from loopy.codegen import POD  # uses the correct complex type
-- 
GitLab