diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index caee73eb1c3320f03ceac66e55e8f5c0bfadbbc2..c111a02b75243b10de90b2d18d62e3759c575fa8 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -525,11 +525,17 @@ class ExpressionToCExpressionMapper(IdentityMapper): real_sum = p.flattened_sum([self.rec(r, type_context) for r in reals]) - complex_sum = self.rec(complexes[0], type_context, tgt_dtype) - for child in complexes[1:]: - complex_sum = var("%s_add" % tgt_name)( - complex_sum, - self.rec(child, type_context, tgt_dtype)) + c_applied = [self.rec(c, type_context, tgt_dtype) for c in complexes] + + def binary_tree_add(start, end): + if start + 1 == end: + return c_applied[start] + mid = (start + end)//2 + lsum = binary_tree_add(start, mid) + rsum = binary_tree_add(mid, end) + return var("%s_add" % tgt_name)(lsum, rsum) + + complex_sum = binary_tree_add(0, len(c_applied)) if real_sum: return var("%s_radd" % tgt_name)(real_sum, complex_sum) @@ -569,11 +575,17 @@ class ExpressionToCExpressionMapper(IdentityMapper): real_prd = p.flattened_product( [self.rec(r, type_context) for r in reals]) - complex_prd = self.rec(complexes[0], type_context, tgt_dtype) - for child in complexes[1:]: - complex_prd = var("%s_mul" % tgt_name)( - complex_prd, - self.rec(child, type_context, tgt_dtype)) + c_applied = [self.rec(c, type_context, tgt_dtype) for c in complexes] + + def binary_tree_mul(start, end): + if start + 1 == end: + return c_applied[start] + mid = (start + end)//2 + lsum = binary_tree_mul(start, mid) + rsum = binary_tree_mul(mid, end) + return var("%s_mul" % tgt_name)(lsum, rsum) + + complex_prd = binary_tree_mul(0, len(complexes)) if real_prd: return var("%s_rmul" % tgt_name)(real_prd, complex_prd)