From fa3d2344176d10ac4ab111b8eca1adc3c86da900 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 21 Mar 2016 15:20:35 -0500 Subject: [PATCH] Fix Bessel function mangler for loopy type system changes --- sumpy/codegen.py | 30 ++++++++++++++++++++---------- sumpy/version.py | 2 +- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/sumpy/codegen.py b/sumpy/codegen.py index cf1f87d0..8a5f7df4 100644 --- a/sumpy/codegen.py +++ b/sumpy/codegen.py @@ -33,6 +33,8 @@ import re from pymbolic.mapper import IdentityMapper, WalkMapper, CSECachingMapperMixin import pymbolic.primitives as prim +from loopy.types import NumpyType + from pytools import memoize_method from pymbolic.sympy_interface import ( @@ -160,22 +162,30 @@ def bessel_mangler(kernel, identifier, arg_dtypes): raise NotImplementedError("Only the PyOpenCLTarget is supported as of now") if identifier == "hank1_01": - if arg_dtypes[0].kind == "c": + if arg_dtypes[0].is_complex(): identifier = "hank1_01_complex" - return (np.dtype(hank1_01_result_dtype), - identifier, (np.dtype(np.complex128),)) + return (NumpyType(np.dtype(hank1_01_result_dtype)), + identifier, ( + NumpyType(np.dtype(np.complex128)), + )) else: - return (np.dtype(hank1_01_result_dtype), - identifier, (np.dtype(np.float64),)) + return (NumpyType(np.dtype(hank1_01_result_dtype)), + identifier, ( + NumpyType(np.dtype(np.float64)), + )) elif identifier == "bessel_jv_two": - if arg_dtypes[1].kind == "c": + if arg_dtypes[1].is_complex(): identifier = "bessel_jv_two_complex" - return (np.dtype(bessel_j_two_result_dtype), - identifier, (np.dtype(np.int32), np.dtype(np.complex128),)) + return (NumpyType(np.dtype(bessel_j_two_result_dtype)), + identifier, ( + NumpyType(np.dtype(np.int32)), + NumpyType(np.dtype(np.complex128)),)) else: - return (np.dtype(bessel_j_two_result_dtype), - identifier, (np.dtype(np.int32), np.dtype(np.float64),)) + return (NumpyType(np.dtype(bessel_j_two_result_dtype)), + identifier, ( + NumpyType(np.dtype(np.int32)), + NumpyType(np.dtype(np.float64)),)) else: return None diff --git a/sumpy/version.py b/sumpy/version.py index 4765172e..04406292 100644 --- a/sumpy/version.py +++ b/sumpy/version.py @@ -25,4 +25,4 @@ VERSION = (2016, 1) VERSION_STATUS = "beta1" VERSION_TEXT = ".".join(str(x) for x in VERSION) + VERSION_STATUS -KERNEL_VERSION = 2 +KERNEL_VERSION = 3 -- GitLab