diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 20132276edece9822ecf1fc2d94dedc4691b4b7f..fd39956c1a6fd0666b20ffee3c672da6190f3a03 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -187,7 +187,8 @@ class POD(PODBase): # {{{ main code generation entrypoint -def generate_code(kernel, with_annotation=False): +def generate_code(kernel, with_annotation=False, + allow_complex=False): from cgen import (FunctionBody, FunctionDeclaration, Value, ArrayOf, Module, Block, Line, Const, LiteralLines, Initializer) @@ -196,8 +197,9 @@ def generate_code(kernel, with_annotation=False): CLLocal, CLImage, CLConstant) from loopy.codegen.expression import LoopyCCodeMapper - ccm = (LoopyCCodeMapper(kernel, with_annotation=with_annotation) - .copy_and_assign_many(make_initial_assignments(kernel))) + ccm = (LoopyCCodeMapper(kernel, with_annotation=with_annotation, + allow_complex=allow_complex) + .copy_and_assign_many(make_initial_assignments(kernel))) mod = Module() diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index fc9e8a3b7ab3c3423d682e85602910c6e077623f..b871721d2730abc787ff6c631cf97da1c52563b0 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -62,7 +62,7 @@ class TypeInferenceMapper(CombineMapper): class LoopyCCodeMapper(CCodeMapper): def __init__(self, kernel, cse_name_list=[], var_subst_map={}, - with_annotation=False, ): + with_annotation=False, allow_complex=False): def constant_mapper(c): if isinstance(c, float): # FIXME: type-variable @@ -74,6 +74,7 @@ class LoopyCCodeMapper(CCodeMapper): cse_name_list=cse_name_list) self.kernel = kernel self.infer_type = TypeInferenceMapper(kernel) + self.allow_complex = allow_complex self.with_annotation = with_annotation self.var_subst_map = var_subst_map.copy() @@ -85,7 +86,8 @@ class LoopyCCodeMapper(CCodeMapper): cse_name_list = self.cse_name_list return LoopyCCodeMapper(self.kernel, cse_name_list=cse_name_list, var_subst_map=var_subst_map, - with_annotation=self.with_annotation) + with_annotation=self.with_annotation, + allow_complex=self.allow_complex) def copy_and_assign(self, name, value): """Make a copy of self with variable *name* fixed to *value*.""" @@ -214,6 +216,9 @@ class LoopyCCodeMapper(CCodeMapper): raise RuntimeError def map_sum(self, expr, enclosing_prec): + if not self.allow_complex: + return CCodeMapper.map_sum(self, expr, enclosing_prec) + tgt_dtype = self.infer_type(expr) is_complex = tgt_dtype.kind == 'c' @@ -239,6 +244,9 @@ class LoopyCCodeMapper(CCodeMapper): return self.parenthesize_if_needed(result, enclosing_prec, PREC_SUM) def map_product(self, expr, enclosing_prec): + if not self.allow_complex: + return CCodeMapper.map_product(self, expr, enclosing_prec) + tgt_dtype = self.infer_type(expr) is_complex = 'c' == tgt_dtype.kind @@ -270,6 +278,9 @@ class LoopyCCodeMapper(CCodeMapper): return self.parenthesize_if_needed(result, enclosing_prec, PREC_PRODUCT) def map_quotient(self, expr, enclosing_prec): + if not self.allow_complex: + return CCodeMapper.map_quotient(self, expr, enclosing_prec) + n_complex = 'c' == self.infer_type(expr.numerator).kind d_complex = 'c' == self.infer_type(expr.denominator).kind @@ -292,6 +303,9 @@ class LoopyCCodeMapper(CCodeMapper): self.rec(expr.denominator, PREC_NONE)) def map_remainder(self, expr, enclosing_prec): + if not self.allow_complex: + return CCodeMapper.map_remainder(self, expr, enclosing_prec) + tgt_dtype = self.infer_type(expr) if 'c' == tgt_dtype.kind: raise RuntimeError("complex remainder not defined") @@ -299,6 +313,9 @@ class LoopyCCodeMapper(CCodeMapper): return CCodeMapper.map_remainder(self, expr, enclosing_prec) def map_power(self, expr, enclosing_prec): + if not self.allow_complex: + return CCodeMapper.map_power(self, expr, enclosing_prec) + from pymbolic.mapper.stringifier import PREC_NONE tgt_dtype = self.infer_type(expr)