diff --git a/MEMO b/MEMO
index 15d4ae95b84e177a93bfed0576158b9c832a8364..62bfa7f35464caffb570b87132241acdc72ca0c3 100644
--- a/MEMO
+++ b/MEMO
@@ -46,6 +46,8 @@ To-do
 
 - Fix timer / call code
 
+- use memory pools for arrays
+
 Fixes:
 
 - Group instructions by dependency/inames for scheduling, to
diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index deb20098f5356dfd989ab85d86ed244dc23ca975..73ed05b43475f8960e6c1271046b6f06fd3d8b16 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -49,7 +49,32 @@ class TypeInferenceMapper(CombineMapper):
         return result
 
     def map_constant(self, expr):
-        return np.asarray(expr).dtype
+        """Infer the dtype of the constant *expr*.
+
+        Python integers map to the smallest of int8..int64 that holds them;
+        real and complex constants default to float32 and complex64.
+        Raises TypeInferenceFailure when no type can be deduced.
+        """
+        if isinstance(expr, int):
+            for tp in [np.int8, np.int16, np.int32, np.int64]:
+                iinfo = np.iinfo(tp)
+                if iinfo.min <= expr <= iinfo.max:
+                    return np.dtype(tp)
+
+            else:
+                raise TypeInferenceFailure("integer constant '%s' too large" % expr)
+
+        dt = np.asarray(expr).dtype
+        if dt.kind == "f":
+            # real constant: deduce the smaller type by default
+            return np.dtype(np.float32)
+        elif dt.kind == "c":
+            # complex constant: deduce the smaller type by default
+            return np.dtype(np.complex64)
+        elif hasattr(expr, "dtype"):
+            return expr.dtype
+        else:
+            raise TypeInferenceFailure("cannot deduce type of constant '%s'" % expr)
 
     def map_subscript(self, expr):
         return self.rec(expr.aggregate)
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 6e73e82ef0d78975ec4a74d2a71c73f5b68b6378..595a263e7abf1d3683134cc4245a965902a69cb1 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -14,6 +14,31 @@ __all__ = ["pytest_generate_tests",
 
 
 
+def test_type_inference_no_artificial_doubles(ctx_factory):
+    """Check that a kernel with only float32 arguments generates no 'double'."""
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "{[i]: 0<=i<n}",
+            """
+                <> bb = a[i] - b[i]
+                c[i] = bb
+            """,
+            [
+                lp.GlobalArg("a", np.float32, shape=("n",)),
+                lp.GlobalArg("b", np.float32, shape=("n",)),
+                lp.GlobalArg("c", np.float32, shape=("n",)),
+                lp.ValueArg("n", np.int32),
+                ],
+            assumptions="n>=1")
+
+    for k in lp.generate_loop_schedules(knl):
+        code = lp.generate_code(k)
+        assert "double" not in code
+
+
+
+
 def test_simple_side_effect(ctx_factory):
     ctx = ctx_factory()
 