From 67ad6af03fa2d6632efc4d1d2282cc9b6e4ee236 Mon Sep 17 00:00:00 2001 From: tj-sun Date: Mon, 27 Nov 2017 16:08:23 +0000 Subject: [PATCH 1/2] add manglers for CTarget to handle min, max, NAN --- loopy/target/c/__init__.py | 62 ++++++++++++++++++++++++++++++++++++++ loopy/target/opencl.py | 24 --------------- test/test_target.py | 26 ++++++++++++++++ 3 files changed, 88 insertions(+), 24 deletions(-) diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index 832c224f3..4f4ac4f31 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -27,12 +27,14 @@ THE SOFTWARE. import six import numpy as np # noqa +from loopy.kernel.data import CallMangleInfo from loopy.target import TargetBase, ASTBuilderBase, DummyHostASTBuilder from loopy.diagnostic import LoopyError from cgen import Pointer, NestedDeclarator, Block from cgen.mapper import IdentityMapper as CASTIdentityMapperBase from pymbolic.mapper.stringifier import PREC_NONE from loopy.symbolic import IdentityMapper +from loopy.types import NumpyType import pymbolic.primitives as p from pytools import memoize_method @@ -313,9 +315,69 @@ class _ConstPointer(Pointer): return sub_tp, ("*const %s" % sub_decl) +# {{{ symbol mangler + +def c_symbol_mangler(kernel, name): + # float NAN as defined in C99 standard + if name == "NAN": + return NumpyType(np.dtype(np.float32)), name + return None + +# }}} + + +# {{{ function mangler + +def c_function_mangler(target, name, arg_dtypes): + # convert abs(), min(), max() to fabs(), fmin(), fmax() to comply with + # C99 standard + if not isinstance(name, str): + return None + + if (name == "abs" + and len(arg_dtypes) == 1 + and arg_dtypes[0].numpy_dtype.kind == "f"): + return CallMangleInfo( + target_name="fabs", + result_dtypes=arg_dtypes, + arg_dtypes=arg_dtypes) + + if name in ["max", "min"] and len(arg_dtypes) == 2: + dtype = np.find_common_type( + [], [dtype.numpy_dtype for dtype in arg_dtypes]) + + if dtype.kind == "c": + raise RuntimeError("min/max do not support complex numbers") + + if dtype.kind == "f": + name = "f" + name + + result_dtype = NumpyType(dtype) + return CallMangleInfo( + target_name=name, + result_dtypes=(result_dtype,), + arg_dtypes=2*(result_dtype,)) + + return None + +# }}} + + class CASTBuilder(ASTBuilderBase): # {{{ library + def function_manglers(self): + return ( + super(CASTBuilder, self).function_manglers() + [ + c_function_mangler + ]) + + def symbol_manglers(self): + return ( + super(CASTBuilder, self).symbol_manglers() + [ + c_symbol_mangler + ]) + def preamble_generators(self): return ( super(CASTBuilder, self).preamble_generators() + [ diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 50d6acc7a..2763caace 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -167,30 +167,6 @@ def opencl_function_mangler(kernel, name, arg_dtypes): if not isinstance(name, str): return None - if (name == "abs" - and len(arg_dtypes) == 1 - and arg_dtypes[0].numpy_dtype.kind == "f"): - return CallMangleInfo( - target_name="fabs", - result_dtypes=arg_dtypes, - arg_dtypes=arg_dtypes) - - if name in ["max", "min"] and len(arg_dtypes) == 2: - dtype = np.find_common_type( - [], [dtype.numpy_dtype for dtype in arg_dtypes]) - - if dtype.kind == "c": - raise RuntimeError("min/max do not support complex numbers") - - if dtype.kind == "f": - name = "f" + name - - result_dtype = NumpyType(dtype) - return CallMangleInfo( - target_name=name, - result_dtypes=(result_dtype,), - arg_dtypes=2*(result_dtype,)) - if name == "dot": scalar_dtype, offset, field_name = arg_dtypes[0].numpy_dtype.fields["s0"] return CallMangleInfo( diff --git a/test/test_target.py b/test/test_target.py index 01a2e5d9d..ff8774b76 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -140,6 +140,32 @@ def test_generate_c_snippet(): print(lp.generate_body(knl)) +def test_c_min_max(): + # Test fmin() fmax() is generated for C backend instead of max() and min() + from loopy.target.c import CTarget + import pymbolic.primitives as p + i = p.Variable("i") + xi = p.Subscript(p.Variable("x"), i) + yi = p.Subscript(p.Variable("y"), i) + zi = p.Subscript(p.Variable("z"), i) + + N = 100 + domain = "{[i]: 0<=i<%d}" % N + data = [lp.GlobalArg("x", np.float64, shape=(N,)), + lp.GlobalArg("y", np.float64, shape=(N,)), + lp.GlobalArg("z", np.float64, shape=(N,))] + + inst = [lp.Assignment(xi, p.Variable("min")(yi, zi))] + knl = lp.make_kernel(domain, inst, data, target=CTarget()) + code = lp.generate_code_v2(knl).device_code() + assert "fmin" in code + + inst = [lp.Assignment(xi, p.Variable("max")(yi, zi))] + knl = lp.make_kernel(domain, inst, data, target=CTarget()) + code = lp.generate_code_v2(knl).device_code() + assert "fmax" in code + + @pytest.mark.parametrize("tp", ["f32", "f64"]) def test_random123(ctx_factory, tp): ctx = ctx_factory() -- GitLab From ee70134a2843fa94436f0ac0e5134c7c5b902f53 Mon Sep 17 00:00:00 2001 From: tj-sun Date: Mon, 27 Nov 2017 16:26:34 +0000 Subject: [PATCH 2/2] flake8 --- test/test_target.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_target.py b/test/test_target.py index ff8774b76..aa6f00463 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -149,11 +149,11 @@ def test_c_min_max(): yi = p.Subscript(p.Variable("y"), i) zi = p.Subscript(p.Variable("z"), i) - N = 100 - domain = "{[i]: 0<=i<%d}" % N - data = [lp.GlobalArg("x", np.float64, shape=(N,)), - lp.GlobalArg("y", np.float64, shape=(N,)), - lp.GlobalArg("z", np.float64, shape=(N,))] + n = 100 + domain = "{[i]: 0<=i<%d}" % n + data = [lp.GlobalArg("x", np.float64, shape=(n,)), + lp.GlobalArg("y", np.float64, shape=(n,)), + lp.GlobalArg("z", np.float64, shape=(n,))] inst = [lp.Assignment(xi, p.Variable("min")(yi, zi))] knl = lp.make_kernel(domain, inst, data, target=CTarget()) -- GitLab