diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index 1ddb9ec30c6b4faec3ddf95413bd7ad5d0ded3a9..5772bf7bcdf01245e496fca28e820c8457dcfc22 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -77,6 +77,24 @@ class TypeInferenceMapper(CombineMapper): return result + def map_sum(self, expr): + dtypes = [] + small_integer_dtypes = [] + for child in expr.children: + dtype = self.rec(child) + if isinstance(child, (int, np.integer)) and abs(child) < 1024: + small_integer_dtypes.append(dtype) + else: + dtypes.append(dtype) + + from pytools import all + if all(dtype.kind == "i" for dtype in dtypes): + dtypes.extend(small_integer_dtypes) + + return self.combine(dtypes) + + map_product = map_sum + def map_quotient(self, expr): n_dtype = self.rec(expr.numerator) d_dtype = self.rec(expr.denominator)