diff --git a/sumpy/codegen.py b/sumpy/codegen.py index cf1f87d02917cb656c78e93d743395e9dace043f..8a5f7df43a06f992f6cf82d04006a3ebf54a9edd 100644 --- a/sumpy/codegen.py +++ b/sumpy/codegen.py @@ -33,6 +33,8 @@ import re from pymbolic.mapper import IdentityMapper, WalkMapper, CSECachingMapperMixin import pymbolic.primitives as prim +from loopy.types import NumpyType + from pytools import memoize_method from pymbolic.sympy_interface import ( @@ -160,22 +162,30 @@ def bessel_mangler(kernel, identifier, arg_dtypes): raise NotImplementedError("Only the PyOpenCLTarget is supported as of now") if identifier == "hank1_01": - if arg_dtypes[0].kind == "c": + if arg_dtypes[0].is_complex(): identifier = "hank1_01_complex" - return (np.dtype(hank1_01_result_dtype), - identifier, (np.dtype(np.complex128),)) + return (NumpyType(np.dtype(hank1_01_result_dtype)), + identifier, ( + NumpyType(np.dtype(np.complex128)), + )) else: - return (np.dtype(hank1_01_result_dtype), - identifier, (np.dtype(np.float64),)) + return (NumpyType(np.dtype(hank1_01_result_dtype)), + identifier, ( + NumpyType(np.dtype(np.float64)), + )) elif identifier == "bessel_jv_two": - if arg_dtypes[1].kind == "c": + if arg_dtypes[1].is_complex(): identifier = "bessel_jv_two_complex" - return (np.dtype(bessel_j_two_result_dtype), - identifier, (np.dtype(np.int32), np.dtype(np.complex128),)) + return (NumpyType(np.dtype(bessel_j_two_result_dtype)), + identifier, ( + NumpyType(np.dtype(np.int32)), + NumpyType(np.dtype(np.complex128)),)) else: - return (np.dtype(bessel_j_two_result_dtype), - identifier, (np.dtype(np.int32), np.dtype(np.float64),)) + return (NumpyType(np.dtype(bessel_j_two_result_dtype)), + identifier, ( + NumpyType(np.dtype(np.int32)), + NumpyType(np.dtype(np.float64)),)) else: return None diff --git a/sumpy/version.py b/sumpy/version.py index 4765172e125c186555ee5d5730cf03295fcd7d0f..04406292abfa5ab2d452b6c6e0a7ae210f166a26 100644 --- a/sumpy/version.py +++ b/sumpy/version.py @@ -25,4 +25,4 @@ VERSION = (2016, 1) VERSION_STATUS = "beta1" VERSION_TEXT = ".".join(str(x) for x in VERSION) + VERSION_STATUS -KERNEL_VERSION = 2 +KERNEL_VERSION = 3