From e7ccefc67b47401f98dc90639bc0e4a1884112e4 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner
Date: Thu, 4 Oct 2012 19:59:48 -0400
Subject: [PATCH] Make type inference of 'float-float' not come out to double.

TypeInferenceMapper.map_constant no longer returns whatever dtype
np.asarray() picks for a constant. Instead:

- Python int constants get the smallest of int8/int16/int32/int64 whose
  range contains the value; the for/else raises TypeInferenceFailure
  only when none of them fits.
- Floating-point constants are inferred as float32 and complex constants
  as complex64, i.e. the smaller type of each kind.
- Any other constant that carries a .dtype attribute keeps that dtype;
  everything else raises TypeInferenceFailure.

The new test checks that no "double" shows up in the code generated for
a float32-only kernel.

---
 MEMO                        |  2 ++
 loopy/codegen/expression.py | 21 ++++++++++++++++++++-
 test/test_loopy.py          | 24 ++++++++++++++++++++++++
 3 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/MEMO b/MEMO
index 15d4ae95b..62bfa7f35 100644
--- a/MEMO
+++ b/MEMO
@@ -46,6 +46,8 @@ To-do
 
 - Fix timer / call code
 
+- use memory pools for arrays
+
 Fixes:
 
 - Group instructions by dependency/inames for scheduling, to
diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index deb20098f..73ed05b43 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -49,7 +49,26 @@ class TypeInferenceMapper(CombineMapper):
         return result
 
     def map_constant(self, expr):
-        return np.asarray(expr).dtype
+        if isinstance(expr, int):
+            for tp in [np.int8, np.int16, np.int32, np.int64]:
+                iinfo = np.iinfo(tp)
+                if iinfo.min <= expr <= iinfo.max:
+                    return np.dtype(tp)
+
+            else:
+                raise TypeInferenceFailure("integer constant '%s' too large" % expr)
+
+        dt = np.asarray(expr).dtype
+        if dt.kind == "f":
+            # real floating point: deduce the smaller type by default
+            return np.dtype(np.float32)
+        elif dt.kind == "c":
+            # complex floating point: deduce the smaller type by default
+            return np.dtype(np.complex64)
+        elif hasattr(expr, "dtype"):
+            return expr.dtype
+        else:
+            raise TypeInferenceFailure("cannot deduce type of constant '%s'" % expr)
 
     def map_subscript(self, expr):
         return self.rec(expr.aggregate)
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 6e73e82ef..595a263e7 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -14,6 +14,30 @@ __all__ = ["pytest_generate_tests",
 
 
 
+def test_type_inference_no_artificial_doubles(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "{[i]: 0<=i<n}",
+            """
+                <> bb = a[i] - b[i]
+                c[i] = bb
+                """,
+            [
+                lp.GlobalArg("a", np.float32, shape=("n",)),
+                lp.GlobalArg("b", np.float32, shape=("n",)),
+                lp.GlobalArg("c", np.float32, shape=("n",)),
+                lp.ValueArg("n", np.int32),
+                ],
+            assumptions="n>=1")
+
+    for k in lp.generate_loop_schedules(knl):
+        code = lp.generate_code(k)
+        assert "double" not in code
+
+
+
+
 def test_simple_side_effect(ctx_factory):
     ctx = ctx_factory()
 
-- 
GitLab