diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 37c8a12ee125fb26c54e84918a538c8d36e0cb4a..225f7e7fec30d38159a93df6ec4df35a45c21bc6 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -561,7 +561,7 @@ def generate_code(kernel, device=None): preamble_generators = (kernel.preamble_generators + kernel.target.preamble_generators()) for prea_gen in preamble_generators: - preambles.extend(prea_gen(kernel.target, seen_dtypes, seen_functions)) + preambles.extend(prea_gen(kernel, seen_dtypes, seen_functions)) seen_preamble_tags = set() dedup_preambles = [] diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index d2d178bc459cb0c231a78a5a1b2c3b8092d07536..13afaa66d05b8dce89a2eb3f1f06e8b752dc5420 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -302,7 +302,7 @@ class LoopKernel(RecordWithoutPickling): manglers = self.target.function_manglers() + self.function_manglers for mangler in manglers: - mangle_result = mangler(self.target, identifier, arg_dtypes) + mangle_result = mangler(self, identifier, arg_dtypes) if mangle_result is not None: return mangle_result @@ -316,7 +316,7 @@ class LoopKernel(RecordWithoutPickling): manglers = self.target.symbol_manglers() + self.symbol_manglers for mangler in manglers: - result = mangler(self.target, identifier) + result = mangler(self, identifier) if result is not None: return result diff --git a/loopy/library/function.py b/loopy/library/function.py index e494169bbe5b83df852ea3d483ed3640381891f6..df623a4770f4f14a7952ee2e0edbf59939de1cfd 100644 --- a/loopy/library/function.py +++ b/loopy/library/function.py @@ -23,19 +23,19 @@ THE SOFTWARE. """ -def default_function_mangler(target, name, arg_dtypes): +def default_function_mangler(kernel, name, arg_dtypes): from loopy.library.reduction import reduction_function_mangler manglers = [reduction_function_mangler] for mangler in manglers: - result = mangler(target, name, arg_dtypes) + result = mangler(kernel, name, arg_dtypes) if result is not None: return result return None -def single_arg_function_mangler(target, name, arg_dtypes): +def single_arg_function_mangler(kernel, name, arg_dtypes): if len(arg_dtypes) == 1: dtype, = arg_dtypes return dtype, name diff --git a/loopy/target/opencl/__init__.py b/loopy/target/opencl/__init__.py index eebe6f5da0b81fa9b4c1ac7b4cda0ba8b1ac283e..fcc6819ba247cf06127d3095e57a55226e2b9a2d 100644 --- a/loopy/target/opencl/__init__.py +++ b/loopy/target/opencl/__init__.py @@ -105,7 +105,7 @@ def _register_vector_types(dtype_registry): # {{{ function mangler -def opencl_function_mangler(target, name, arg_dtypes): +def opencl_function_mangler(kernel, name, arg_dtypes): if not isinstance(name, str): return None @@ -134,7 +134,7 @@ def opencl_function_mangler(target, name, arg_dtypes): # {{{ symbol mangler -def opencl_symbol_mangler(target, name): +def opencl_symbol_mangler(kernel, name): # FIXME: should be more picky about exact names if name.startswith("FLT_"): return np.dtype(np.float32), name @@ -155,7 +155,7 @@ def opencl_symbol_mangler(target, name): # {{{ preamble generator -def opencl_preamble_generator(target, seen_dtypes, seen_functions): +def opencl_preamble_generator(kernel, seen_dtypes, seen_functions): has_double = False for dtype in seen_dtypes: