From d316087d30c7c6891fa23b600c71228416223fe3 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 23 Jun 2015 14:47:26 -0500 Subject: [PATCH] Make type inference deterministic dtype.isbuiltin can apparently be both True and False for simple types like np.int32, depending on how the dtype is constructed. That makes it unsuited to detect user-created dtypes. --- loopy/expression.py | 11 +++++++---- test/test_statistics.py | 19 ++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/loopy/expression.py b/loopy/expression.py index e4c89c735..d912d528a 100644 --- a/loopy/expression.py +++ b/loopy/expression.py @@ -1,4 +1,4 @@ -from __future__ import division, absolute_import +from __future__ import division, absolute_import, print_function __copyright__ = "Copyright (C) 2012-15 Andreas Kloeckner" @@ -79,7 +79,7 @@ class TypeInferenceMapper(CombineMapper): while dtypes: other = dtypes.pop() - if result.isbuiltin and other.isbuiltin: + if result.fields is None and other.fields is None: if (result, other) in [ (np.int32, np.float32), (np.float32, np.int32)]: # numpy makes this a double. I disagree. @@ -89,11 +89,14 @@ class TypeInferenceMapper(CombineMapper): np.empty(0, dtype=result) + np.empty(0, dtype=other) ).dtype - elif result.isbuiltin and not other.isbuiltin: + + elif result.fields is None and other.fields is not None: # assume the non-native type takes over + # (This is used for vector types.) result = other - elif not result.isbuiltin and other.isbuiltin: + elif result.fields is not None and other.fields is None: # assume the non-native type takes over + # (This is used for vector types.) pass else: if result is not other: diff --git a/test/test_statistics.py b/test/test_statistics.py index 5bcd641e5..6867ca28f 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -137,18 +137,23 @@ def test_op_counter_bitwise(): ], name="bitwise", assumptions="n,m,l >= 1") - knl = lp.add_and_infer_dtypes(knl, - dict( - a=np.int32, b=np.int32, - g=np.int64, h=np.int64)) + knl = lp.add_and_infer_dtypes( + knl, dict( + a=np.int32, b=np.int32, + g=np.int64, h=np.int64)) + poly = get_op_poly(knl) + n = 10 m = 10 l = 10 - i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l}) + param_values = {'n': n, 'm': m, 'l': l} + i32 = poly.dict[np.dtype(np.int32)].eval_with_dict(param_values) + i64 = poly.dict[np.dtype(np.int64)].eval_with_dict(param_values) + not_there = poly[np.dtype(np.float64)].eval_with_dict(param_values) print(poly.dict) - not_there = poly[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l}) - assert i32 == 3*n*m + n*m*l + assert i32 == n*m + n*m*l + assert i64 == 2*n*m assert not_there == 0 -- GitLab