diff --git a/loopy/expression.py b/loopy/expression.py index e4c89c735d3446828429f8de3a8daad940e81e06..d912d528a7cc6a567db5aca684103eb736737030 100644 --- a/loopy/expression.py +++ b/loopy/expression.py @@ -1,4 +1,4 @@ -from __future__ import division, absolute_import +from __future__ import division, absolute_import, print_function __copyright__ = "Copyright (C) 2012-15 Andreas Kloeckner" @@ -79,7 +79,7 @@ class TypeInferenceMapper(CombineMapper): while dtypes: other = dtypes.pop() - if result.isbuiltin and other.isbuiltin: + if result.fields is None and other.fields is None: if (result, other) in [ (np.int32, np.float32), (np.float32, np.int32)]: # numpy makes this a double. I disagree. @@ -89,11 +89,14 @@ class TypeInferenceMapper(CombineMapper): np.empty(0, dtype=result) + np.empty(0, dtype=other) ).dtype - elif result.isbuiltin and not other.isbuiltin: + + elif result.fields is None and other.fields is not None: # assume the non-native type takes over + # (This is used for vector types.) result = other - elif not result.isbuiltin and other.isbuiltin: + elif result.fields is not None and other.fields is None: # assume the non-native type takes over + # (This is used for vector types.) pass else: if result is not other: diff --git a/test/test_statistics.py b/test/test_statistics.py index 5bcd641e51ef284c31ac61b24664a1b26d5e7fd1..6867ca28f696b4e9de827ecb340698ada6e8cfa9 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -137,18 +137,23 @@ def test_op_counter_bitwise(): ], name="bitwise", assumptions="n,m,l >= 1") - knl = lp.add_and_infer_dtypes(knl, - dict( - a=np.int32, b=np.int32, - g=np.int64, h=np.int64)) + knl = lp.add_and_infer_dtypes( + knl, dict( + a=np.int32, b=np.int32, + g=np.int64, h=np.int64)) + poly = get_op_poly(knl) + n = 10 m = 10 l = 10 - i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l}) + param_values = {'n': n, 'm': m, 'l': l} + i32 = poly.dict[np.dtype(np.int32)].eval_with_dict(param_values) + i64 = poly.dict[np.dtype(np.int64)].eval_with_dict(param_values) + not_there = poly[np.dtype(np.float64)].eval_with_dict(param_values) print(poly.dict) - not_there = poly[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l}) - assert i32 == 3*n*m + n*m*l + assert i32 == n*m + n*m*l + assert i64 == 2*n*m assert not_there == 0