diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index e4a512957e5cafdb310552399f53b3b8f4274b07..1ddb9ec30c6b4faec3ddf95413bd7ad5d0ded3a9 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -77,9 +77,25 @@ class TypeInferenceMapper(CombineMapper):
 
         return result
 
+    def map_quotient(self, expr):
+        """Infer the result dtype of the quotient *expr*.
+
+        An integer/integer quotient is typed as float64; any other
+        operand combination is handed to self.combine().
+        """
+        n_dtype = self.rec(expr.numerator)
+        d_dtype = self.rec(expr.denominator)
+
+        if n_dtype.kind in "iu" and d_dtype.kind in "iu":
+            # both integers
+            return np.dtype(np.float64)
+
+        else:
+            return self.combine([n_dtype, d_dtype])
+
     def map_constant(self, expr):
         if isinstance(expr, int):
-            for tp in [np.int8, np.int16, np.int32, np.int64]:
+            for tp in [np.int32, np.int64]:
                 iinfo = np.iinfo(tp)
                 if iinfo.min <= expr <= iinfo.max:
                     return np.dtype(tp)
@@ -547,8 +563,11 @@ class LoopyCCodeMapper(RecursiveMapper):
         self.seen_functions.add((identifier, c_name, par_dtypes))
 
         if str_parameters is None:
+            # /!\ FIXME For some functions (e.g. 'sin'), it makes sense to
+            # propagate the type context here. But for many others, it does
+            # not.
             str_parameters = [
-                    self.rec(par, PREC_NONE, type_context)
+                    self.rec(par, PREC_NONE, type_context=None)
                     for par in expr.parameters]
 
         if c_name is None:
@@ -649,21 +668,39 @@ class LoopyCCodeMapper(RecursiveMapper):
 
     def map_quotient(self, expr, enclosing_prec, type_context):
         def base_impl(expr, enclosing_prec, type_context, num_tgt_dtype=None):
+            num = self.rec(expr.numerator, PREC_PRODUCT, type_context, num_tgt_dtype)
+
+            # analogous to ^{-1}
+            denom = self.rec(expr.denominator, PREC_POWER, type_context)
+
+            if n_dtype.kind not in "fc" and d_dtype.kind not in "fc":
+                # must both be integers: cast both operands so the emitted
+                # C '/' is not a truncating integer division
+                # ("f" selects a float cast, "d" a double cast)
+                if type_context == "f":
+                    num = "((float) (%s))" % num
+                    denom = "((float) (%s))" % denom
+                elif type_context == "d":
+                    num = "((double) (%s))" % num
+                    denom = "((double) (%s))" % denom
+
             return self.parenthesize_if_needed(
                     "%s / %s" % (
                         # Space is necessary--otherwise '/*'
                         # (i.e. divide-dererference) becomes
                         # start-of-comment in C.
-                        self.rec(expr.numerator, PREC_PRODUCT, type_context, num_tgt_dtype),
-                        # analogous to ^{-1}
-                        self.rec(expr.denominator, PREC_POWER, type_context)),
+                        num,
+                        denom),
                     enclosing_prec, PREC_PRODUCT)
 
+        n_dtype = self.infer_type(expr.numerator)
+        d_dtype = self.infer_type(expr.denominator)
+
         if not self.allow_complex:
             return base_impl(expr, enclosing_prec, type_context)
 
-        n_complex = 'c' == self.infer_type(expr.numerator).kind
-        d_complex = 'c' == self.infer_type(expr.denominator).kind
+        n_complex = 'c' == n_dtype.kind
+        d_complex = 'c' == d_dtype.kind
 
         tgt_dtype = self.infer_type(expr)
 