diff --git a/loopy/target/c/c_execution.py b/loopy/target/c/c_execution.py index d73912460a2c99075c875375056be5922a98d692..f664e3ee1d1e8544edb218e93acc0b4ee32d1112 100644 --- a/loopy/target/c/c_execution.py +++ b/loopy/target/c/c_execution.py @@ -48,9 +48,12 @@ class CExecutionWrapperGenerator(ExecutionWrapperGeneratorBase): system_args = ["_lpy_c_kernels"] super().__init__(system_args) - def python_dtype_str(self, dtype): + def python_dtype_str_inner(self, dtype): if np.dtype(str(dtype)).isbuiltin: - return "_lpy_np."+dtype.name + name = dtype.name + if dtype.name == "bool": + name = "bool8" + return f"_lpy_np.dtype(_lpy_np.{name})" raise Exception(f"dtype: {dtype} not recognized") # {{{ handle non numpy arguements @@ -100,7 +103,7 @@ class CExecutionWrapperGenerator(ExecutionWrapperGeneratorBase): name=arg.name, shape=strify(sym_shape), dtype=self.python_dtype_str( - kernel_arg.dtype.numpy_dtype), + gen, kernel_arg.dtype.numpy_dtype), order=order)) expected_strides = tuple( diff --git a/loopy/target/execution.py b/loopy/target/execution.py index da5a54e6b4eafe39117b8b01bba2fdf570f3be8b..c00cc28d5fb58da7cc0ae6c1c0c58be048f546d4 100644 --- a/loopy/target/execution.py +++ b/loopy/target/execution.py @@ -22,6 +22,7 @@ THE SOFTWARE. import numpy as np +from abc import ABC, abstractmethod from pytools import ImmutableRecord, memoize_method from loopy.diagnostic import LoopyError from pytools.py_codegen import ( @@ -119,7 +120,7 @@ class SeparateArrayPackingController: # {{{ ExecutionWrapperGeneratorBase -class ExecutionWrapperGeneratorBase: +class ExecutionWrapperGeneratorBase(ABC): """ A set of common methods for generating a wrapper for execution @@ -129,8 +130,25 @@ class ExecutionWrapperGeneratorBase: def __init__(self, system_args): self.system_args = system_args[:] - def python_dtype_str(self, dtype): - raise NotImplementedError() + from pytools import UniqueNameGenerator + self.dtype_name_generator = UniqueNameGenerator(forced_prefix="_lpy_dtype_") + self.dtype_str_to_name = {} + + @abstractmethod + def python_dtype_str_inner(self, dtype): + pass + + def python_dtype_str(self, gen, numpy_dtype): + dtype_str = self.python_dtype_str_inner(numpy_dtype) + try: + return self.dtype_str_to_name[dtype_str] + except KeyError: + pass + + dtype_name = self.dtype_name_generator() + gen.add_to_preamble(f"{dtype_name} = _lpy_np.dtype({dtype_str})") + self.dtype_str_to_name[dtype_str] = dtype_name + return dtype_name # {{{ invoker generation @@ -381,7 +399,7 @@ class ExecutionWrapperGeneratorBase: expect_no_more_arguments = False - for arg_idx, arg in enumerate(implemented_data_info): + for arg in implemented_data_info: is_written = arg.base_name in kernel.get_written_variables() kernel_arg = kernel.impl_arg_to_arg.get(arg.name) @@ -463,7 +481,7 @@ class ExecutionWrapperGeneratorBase: with Indentation(gen): gen("if %s.dtype != %s:" % (arg.name, self.python_dtype_str( - kernel_arg.dtype.numpy_dtype))) + gen, kernel_arg.dtype.numpy_dtype))) with Indentation(gen): gen("raise TypeError(\"dtype mismatch on argument '%s' " '(got: %%s, expected: %s)" %% %s.dtype)' diff --git a/loopy/target/pyopencl_execution.py b/loopy/target/pyopencl_execution.py index d23301077f1f14f85f93074f067b09acf1faa95a..0a9bafde9608624e7285363e4338f96b551307ea 100644 --- a/loopy/target/pyopencl_execution.py +++ b/loopy/target/pyopencl_execution.py @@ -48,13 +48,14 @@ class PyOpenCLExecutionWrapperGenerator(ExecutionWrapperGeneratorBase): "out_host=None" ] super().__init__(system_args) - from pytools import UniqueNameGenerator - self.dtype_name_generator = UniqueNameGenerator(forced_prefix="_lpy_dtype_") - def python_dtype_str(self, dtype): + def python_dtype_str_inner(self, dtype): import pyopencl.tools as cl_tools if dtype.isbuiltin: - return "_lpy_np."+dtype.name + name = dtype.name + if dtype.name == "bool": + name = "bool8" + return f"_lpy_np.dtype(_lpy_np.{name})" else: return ('_lpy_cl_tools.get_or_register_dtype("%s")' % cl_tools.dtype_to_ctype(dtype)) @@ -117,10 +118,8 @@ class PyOpenCLExecutionWrapperGenerator(ExecutionWrapperGeneratorBase): gen("_lpy_size = %s" % strify(size_expr)) sym_strides = tuple(itemsize*s_i for s_i in sym_ustrides) - dtype_str = self.python_dtype_str(kernel_arg.dtype.numpy_dtype) - dtype_name = self.dtype_name_generator() - gen.add_to_preamble(f"{dtype_name} = _lpy_np.dtype({dtype_str})") + dtype_name = self.python_dtype_str(gen, kernel_arg.dtype.numpy_dtype) gen(f"{arg.name} = _lpy_cl_array.Array(None, {strify(sym_shape)}, " f"{dtype_name}, strides={strify(sym_strides)}, " f"data=allocator({strify(itemsize * var('_lpy_size'))}), "