Skip to content
Snippets Groups Projects
Commit 91be1b94 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Improve code generation for constants, 'f' vs no trailing 'f', integer vs non-integer.

parent 76de8414
No related branches found
No related tags found
No related merge requests found
...@@ -65,7 +65,6 @@ To-do ...@@ -65,7 +65,6 @@ To-do
- Scalar insn priority - Scalar insn priority
- What to do about constants in codegen? (...f suffix, complex types)
- If finding a maximum proves troublesome, move parameters into the domain - If finding a maximum proves troublesome, move parameters into the domain
...@@ -123,6 +122,9 @@ Future ideas ...@@ -123,6 +122,9 @@ Future ideas
Dealt with Dealt with
^^^^^^^^^^ ^^^^^^^^^^
- What to do about constants in codegen? (...f suffix, complex types)
-> dealt with by type contexts
- relating to Multi-Domain - relating to Multi-Domain
- Make sure that variables that enter into loop bounds are only written - Make sure that variables that enter into loop bounds are only written
exactly once. [DONE] exactly once. [DONE]
......
...@@ -173,7 +173,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt): ...@@ -173,7 +173,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
from pymbolic import var from pymbolic import var
rhs += iname_coeff*var(iname) rhs += iname_coeff*var(iname)
end_conds.append("%s >= 0" % end_conds.append("%s >= 0" %
ccm(cfm(rhs))) ccm(cfm(rhs), 'i'))
else: # iname_coeff > 0 else: # iname_coeff > 0
kind, bound = solve_constraint_for_bound(cns, iname) kind, bound = solve_constraint_for_bound(cns, iname)
assert kind == ">=" assert kind == ">="
...@@ -205,7 +205,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt): ...@@ -205,7 +205,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
from cgen import For from cgen import For
from loopy.codegen import wrap_in from loopy.codegen import wrap_in
return wrap_in(For, return wrap_in(For,
"int %s = %s" % (iname, ccm(start_expr)), "int %s = %s" % (iname, ccm(start_expr, 'i')),
" && ".join(end_conds), " && ".join(end_conds),
"++%s" % iname, "++%s" % iname,
stmt) stmt)
......
...@@ -2,8 +2,9 @@ from __future__ import division ...@@ -2,8 +2,9 @@ from __future__ import division
import numpy as np import numpy as np
from pymbolic.mapper.c_code import CCodeMapper as CCodeMapper from pymbolic.mapper import RecursiveMapper
from pymbolic.mapper.stringifier import PREC_NONE from pymbolic.mapper.stringifier import (PREC_NONE, PREC_CALL, PREC_PRODUCT,
PREC_POWER)
from pymbolic.mapper import CombineMapper from pymbolic.mapper import CombineMapper
# {{{ type inference # {{{ type inference
...@@ -57,7 +58,7 @@ class TypeInferenceMapper(CombineMapper): ...@@ -57,7 +58,7 @@ class TypeInferenceMapper(CombineMapper):
if isinstance(identifier, Variable): if isinstance(identifier, Variable):
identifier = identifier.name identifier = identifier.name
arg_dtypes = tuple(self.rec(par) for par in expr.parameters) arg_dtypes = tuple(self.rec(par, None) for par in expr.parameters)
mangle_result = self.kernel.mangle_function(identifier, arg_dtypes) mangle_result = self.kernel.mangle_function(identifier, arg_dtypes)
if mangle_result is not None: if mangle_result is not None:
...@@ -118,7 +119,25 @@ def perform_cast(ccm, expr, expr_dtype, target_dtype): ...@@ -118,7 +119,25 @@ def perform_cast(ccm, expr, expr_dtype, target_dtype):
# {{{ C code mapper # {{{ C code mapper
class LoopyCCodeMapper(CCodeMapper): # type_context may be:
# - 'i' for integer
# - 'f' for single-precision floating point
# - 'd' for double-precision floating point
# or None for 'no known context'.
def dtype_to_type_context(dtype):
    """Map a dtype to the type-context code used when generating constants.

    :arg dtype: anything accepted by :class:`numpy.dtype`.
    :returns: ``'i'`` for signed and unsigned integer dtypes, ``'d'`` for
        ``float64``/``complex128``, ``'f'`` for ``float32``/``complex64``,
        and *None* ("no known context") for every other dtype.
    """
    dtype = np.dtype(dtype)

    # Unsigned integers take plain integer literals just like signed ones;
    # without this they would end up with no context at all.
    if dtype.kind in "iu":
        return 'i'
    if dtype in (np.float64, np.complex128):
        return 'd'
    if dtype in (np.float32, np.complex64):
        return 'f'

    return None
class LoopyCCodeMapper(RecursiveMapper):
def __init__(self, kernel, seen_dtypes, seen_functions, var_subst_map={}, def __init__(self, kernel, seen_dtypes, seen_functions, var_subst_map={},
with_annotation=False, allow_complex=False): with_annotation=False, allow_complex=False):
""" """
...@@ -127,7 +146,6 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -127,7 +146,6 @@ class LoopyCCodeMapper(CCodeMapper):
functions that were encountered. functions that were encountered.
""" """
CCodeMapper.__init__(self)
self.kernel = kernel self.kernel = kernel
self.seen_dtypes = seen_dtypes self.seen_dtypes = seen_dtypes
self.seen_functions = seen_functions self.seen_functions = seen_functions
...@@ -138,6 +156,8 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -138,6 +156,8 @@ class LoopyCCodeMapper(CCodeMapper):
self.with_annotation = with_annotation self.with_annotation = with_annotation
self.var_subst_map = var_subst_map.copy() self.var_subst_map = var_subst_map.copy()
# {{{ copy helpers
def copy(self, var_subst_map=None): def copy(self, var_subst_map=None):
if var_subst_map is None: if var_subst_map is None:
var_subst_map = self.var_subst_map var_subst_map = self.var_subst_map
...@@ -146,11 +166,6 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -146,11 +166,6 @@ class LoopyCCodeMapper(CCodeMapper):
with_annotation=self.with_annotation, with_annotation=self.with_annotation,
allow_complex=self.allow_complex) allow_complex=self.allow_complex)
def infer_type(self, expr):
result = self.type_inf_mapper(expr)
self.seen_dtypes.add(result)
return result
def copy_and_assign(self, name, value): def copy_and_assign(self, name, value):
"""Make a copy of self with variable *name* fixed to *value*.""" """Make a copy of self with variable *name* fixed to *value*."""
var_subst_map = self.var_subst_map.copy() var_subst_map = self.var_subst_map.copy()
...@@ -164,18 +179,41 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -164,18 +179,41 @@ class LoopyCCodeMapper(CCodeMapper):
var_subst_map.update(assignments) var_subst_map.update(assignments)
return self.copy(var_subst_map=var_subst_map) return self.copy(var_subst_map=var_subst_map)
def map_common_subexpression(self, expr, prec): # }}}
# {{{ helpers
def infer_type(self, expr):
    """Run the type inference mapper on *expr* and return the resulting dtype.

    The dtype is also added to ``self.seen_dtypes``.
    """
    dtype = self.type_inf_mapper(expr)
    self.seen_dtypes.add(dtype)
    return dtype
def join_rec(self, joiner, iterable, prec, type_context):
    """Map every element of *iterable* and join the results with *joiner*.

    :arg joiner: separator string placed between the mapped elements.
    :arg iterable: sub-expressions to recurse into; it is consumed exactly
        once, so generators and other one-shot iterators are accepted.
    :arg prec: enclosing precedence handed to :meth:`rec` for each element.
    :arg type_context: type context handed to :meth:`rec` for each element.
    :returns: the joined string (empty if *iterable* is empty).
    """
    # Format each result through "%s" so that non-string results are
    # coerced the same way the format-string approach coerced them.
    return joiner.join(
            "%s" % (self.rec(i, prec, type_context),)
            for i in iterable)
def parenthesize_if_needed(self, s, enclosing_prec, my_prec):
    """Wrap *s* in parentheses if the enclosing precedence requires it.

    :arg s: the already-generated code string.
    :arg enclosing_prec: precedence of the surrounding expression, or
        *None* for "no enclosing precedence".
    :arg my_prec: precedence of the operator that produced *s*.
    :returns: ``"(s)"`` if *enclosing_prec* binds tighter than *my_prec*,
        otherwise *s* unchanged.
    """
    # None never forces parentheses. Comparing None against a number only
    # worked by accident of Python 2 ordering and raises on Python 3.
    if enclosing_prec is not None and enclosing_prec > my_prec:
        return "(%s)" % s
    return s
# }}}
def map_common_subexpression(self, expr, prec, type_context):
raise RuntimeError("common subexpression should have been eliminated upon " raise RuntimeError("common subexpression should have been eliminated upon "
"entry to loopy") "entry to loopy")
def map_variable(self, expr, prec): def map_variable(self, expr, enclosing_prec, type_context):
if expr.name in self.var_subst_map: if expr.name in self.var_subst_map:
if self.with_annotation: if self.with_annotation:
return " /* %s */ %s" % ( return " /* %s */ %s" % (
expr.name, expr.name,
self.rec(self.var_subst_map[expr.name], prec)) self.rec(self.var_subst_map[expr.name],
enclosing_prec, type_context))
else: else:
return str(self.rec(self.var_subst_map[expr.name], prec)) return str(self.rec(self.var_subst_map[expr.name],
enclosing_prec, type_context))
elif expr.name in self.kernel.arg_dict: elif expr.name in self.kernel.arg_dict:
arg = self.kernel.arg_dict[expr.name] arg = self.kernel.arg_dict[expr.name]
from loopy.kernel import _ShapedArg from loopy.kernel import _ShapedArg
...@@ -188,15 +226,22 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -188,15 +226,22 @@ class LoopyCCodeMapper(CCodeMapper):
_, c_name = result _, c_name = result
return c_name return c_name
return CCodeMapper.map_variable(self, expr, prec) return expr.name
def map_tagged_variable(self, expr, enclosing_prec): def map_tagged_variable(self, expr, enclosing_prec, type_context):
return expr.name return expr.name
def map_subscript(self, expr, enclosing_prec): def map_subscript(self, expr, enclosing_prec, type_context):
def base_impl(expr, enclosing_prec, type_context):
return self.parenthesize_if_needed(
"%s[%s]" % (
self.rec(expr.aggregate, PREC_CALL, type_context),
self.rec(expr.index, PREC_NONE, 'i')),
enclosing_prec, PREC_CALL)
from pymbolic.primitives import Variable from pymbolic.primitives import Variable
if not isinstance(expr.aggregate, Variable): if not isinstance(expr.aggregate, Variable):
return CCodeMapper.map_subscript(self, expr, enclosing_prec) return base_impl(expr, enclosing_prec, type_context)
if expr.aggregate.name in self.kernel.arg_dict: if expr.aggregate.name in self.kernel.arg_dict:
arg = self.kernel.arg_dict[expr.aggregate.name] arg = self.kernel.arg_dict[expr.aggregate.name]
...@@ -207,7 +252,7 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -207,7 +252,7 @@ class LoopyCCodeMapper(CCodeMapper):
base_access = ("read_imagef(%s, loopy_sampler, (float%d)(%s))" base_access = ("read_imagef(%s, loopy_sampler, (float%d)(%s))"
% (arg.name, arg.dimensions, % (arg.name, arg.dimensions,
", ".join(self.rec(idx, PREC_NONE) ", ".join(self.rec(idx, PREC_NONE, 'i')
for idx in expr.index[::-1]))) for idx in expr.index[::-1])))
if arg.dtype == np.float32: if arg.dtype == np.float32:
...@@ -239,10 +284,11 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -239,10 +284,11 @@ class LoopyCCodeMapper(CCodeMapper):
return "*" + expr.aggregate.name return "*" + expr.aggregate.name
from pymbolic.primitives import Subscript from pymbolic.primitives import Subscript
return CCodeMapper.map_subscript(self, return base_impl(
Subscript(expr.aggregate, arg.offset+sum( Subscript(expr.aggregate, arg.offset+sum(
stride*expr_i for stride, expr_i in zip( stride*expr_i for stride, expr_i in zip(
ary_strides, index_expr))), enclosing_prec) ary_strides, index_expr))),
enclosing_prec, type_context)
elif expr.aggregate.name in self.kernel.temporary_variables: elif expr.aggregate.name in self.kernel.temporary_variables:
...@@ -252,53 +298,68 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -252,53 +298,68 @@ class LoopyCCodeMapper(CCodeMapper):
else: else:
index = (expr.index,) index = (expr.index,)
return (temp_var.name + "".join("[%s]" % self.rec(idx, PREC_NONE) return (temp_var.name + "".join("[%s]" % self.rec(idx, PREC_NONE, 'i')
for idx in index)) for idx in index))
else: else:
raise RuntimeError("nothing known about variable '%s'" % expr.aggregate.name) raise RuntimeError("nothing known about variable '%s'" % expr.aggregate.name)
def map_floor_div(self, expr, prec): def map_floor_div(self, expr, enclosing_prec, type_context):
from loopy.isl_helpers import is_nonnegative from loopy.isl_helpers import is_nonnegative
num_nonneg = is_nonnegative(expr.numerator, self.kernel.domain) num_nonneg = is_nonnegative(expr.numerator, self.kernel.domain)
den_nonneg = is_nonnegative(expr.denominator, self.kernel.domain) den_nonneg = is_nonnegative(expr.denominator, self.kernel.domain)
if den_nonneg: if den_nonneg:
if num_nonneg: if num_nonneg:
return CCodeMapper.map_floor_div(self, expr, prec) return self.parenthesize_if_needed(
"%s // %s" % (
self.rec(expr.numerator, PREC_PRODUCT, type_context),
# analogous to ^{-1}
self.rec(expr.denominator, PREC_POWER, type_context)),
enclosing_prec, PREC_PRODUCT)
else: else:
return ("int_floor_div_pos_b(%s, %s)" return ("int_floor_div_pos_b(%s, %s)"
% (self.rec(expr.numerator, PREC_NONE), % (self.rec(expr.numerator, PREC_NONE, 'i'),
expr.denominator)) self.rec(expr.denominator, PREC_NONE, 'i')))
else: else:
return ("int_floor_div(%s, %s)" return ("int_floor_div(%s, %s)"
% (self.rec(expr.numerator, PREC_NONE), % (self.rec(expr.numerator, PREC_NONE, 'i'),
self.rec(expr.denominator, PREC_NONE))) self.rec(expr.denominator, PREC_NONE, 'i')))
def map_min(self, expr, prec): def map_min(self, expr, prec, type_context):
what = type(expr).__name__.lower() what = type(expr).__name__.lower()
children = expr.children[:] children = expr.children[:]
result = self.rec(children.pop(), PREC_NONE) result = self.rec(children.pop(), PREC_NONE, type_context)
while children: while children:
result = "%s(%s, %s)" % (what, result = "%s(%s, %s)" % (what,
self.rec(children.pop(), PREC_NONE), self.rec(children.pop(), PREC_NONE, type_context),
result) result)
return result return result
map_max = map_min map_max = map_min
def map_constant(self, expr, enclosing_prec): def map_constant(self, expr, enclosing_prec, type_context):
if isinstance(expr, complex): if isinstance(expr, complex):
# FIXME: type-variable cast_type = "cdouble_t"
return "(cdouble_t) (%s, %s)" % (repr(expr.real), repr(expr.imag)) if type_context == "f":
cast_type = "cfloat_t"
return "(%s) (%s, %s)" % (cast_type, repr(expr.real), repr(expr.imag))
else: else:
# FIXME: type-variable if type_context == "f":
return repr(float(expr)) return repr(float(expr))+"f"
elif type_context == "d":
return repr(float(expr))
elif type_context == "i":
return str(int(expr))
else:
raise RuntimeError("don't know how to generated code "
"for constant '%s'" % expr)
def map_call(self, expr, enclosing_prec): def map_call(self, expr, enclosing_prec, type_context):
from pymbolic.primitives import Variable from pymbolic.primitives import Variable
from pymbolic.mapper.stringifier import PREC_NONE from pymbolic.mapper.stringifier import PREC_NONE
...@@ -311,7 +372,7 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -311,7 +372,7 @@ class LoopyCCodeMapper(CCodeMapper):
par_dtypes = tuple(self.infer_type(par) for par in expr.parameters) par_dtypes = tuple(self.infer_type(par) for par in expr.parameters)
parameters = expr.parameters str_parameters = None
mangle_result = self.kernel.mangle_function(identifier, par_dtypes) mangle_result = self.kernel.mangle_function(identifier, par_dtypes)
if mangle_result is not None: if mangle_result is not None:
...@@ -320,23 +381,28 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -320,23 +381,28 @@ class LoopyCCodeMapper(CCodeMapper):
elif len(mangle_result) == 3: elif len(mangle_result) == 3:
result_dtype, c_name, arg_tgt_dtypes = mangle_result result_dtype, c_name, arg_tgt_dtypes = mangle_result
parameters = [ str_parameters = [
perform_cast(self, par, par_dtype, tgt_dtype) self.rec(
perform_cast(self, par, par_dtype, tgt_dtype),
PREC_NONE, dtype_to_type_context(tgt_dtype))
for par, par_dtype, tgt_dtype in zip( for par, par_dtype, tgt_dtype in zip(
parameters, par_dtypes, arg_tgt_dtypes)] expr.parameters, par_dtypes, arg_tgt_dtypes)]
else: else:
raise RuntimeError("result of function mangler " raise RuntimeError("result of function mangler "
"for function '%s' not understood" "for function '%s' not understood"
% identifier) % identifier)
self.seen_functions.add((identifier, c_name, par_dtypes)) self.seen_functions.add((identifier, c_name, par_dtypes))
if str_parameters is None:
str_parameters = [
self.rec(par, PREC_NONE, type_context)
for par in expr.parameters]
if c_name is None: if c_name is None:
raise RuntimeError("unable to find C name for function identifier '%s'" raise RuntimeError("unable to find C name for function identifier '%s'"
% identifier) % identifier)
return self.format("%s(%s)", return "%s(%s)" % (c_name, ", ".join(str_parameters))
c_name, self.join_rec(", ", parameters, PREC_NONE))
# {{{ deal with complex-valued variables # {{{ deal with complex-valued variables
...@@ -348,15 +414,22 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -348,15 +414,22 @@ class LoopyCCodeMapper(CCodeMapper):
else: else:
raise RuntimeError raise RuntimeError
def map_sum(self, expr, enclosing_prec): def map_sum(self, expr, enclosing_prec, type_context):
from pymbolic.mapper.stringifier import PREC_SUM
def base_impl(expr, enclosing_prec, type_context):
return self.parenthesize_if_needed(
self.join_rec(" + ", expr.children, PREC_SUM, type_context),
enclosing_prec, PREC_SUM)
if not self.allow_complex: if not self.allow_complex:
return CCodeMapper.map_sum(self, expr, enclosing_prec) return base_impl(expr, enclosing_prec, type_context)
tgt_dtype = self.infer_type(expr) tgt_dtype = self.infer_type(expr)
is_complex = tgt_dtype.kind == 'c' is_complex = tgt_dtype.kind == 'c'
if not is_complex: if not is_complex:
return CCodeMapper.map_sum(self, expr, enclosing_prec) return base_impl(expr, enclosing_prec, type_context)
else: else:
tgt_name = self.complex_type_name(tgt_dtype) tgt_name = self.complex_type_name(tgt_dtype)
...@@ -365,9 +438,8 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -365,9 +438,8 @@ class LoopyCCodeMapper(CCodeMapper):
complexes = [child for child in expr.children complexes = [child for child in expr.children
if 'c' == self.infer_type(child).kind] if 'c' == self.infer_type(child).kind]
from pymbolic.mapper.stringifier import PREC_SUM real_sum = self.join_rec(" + ", reals, PREC_SUM, type_context)
real_sum = self.join_rec(" + ", reals, PREC_SUM) complex_sum = self.join_rec(" + ", complexes, PREC_SUM, type_context)
complex_sum = self.join_rec(" + ", complexes, PREC_SUM)
if real_sum: if real_sum:
result = "%s_fromreal(%s) + %s" % (tgt_name, real_sum, complex_sum) result = "%s_fromreal(%s) + %s" % (tgt_name, real_sum, complex_sum)
...@@ -376,15 +448,22 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -376,15 +448,22 @@ class LoopyCCodeMapper(CCodeMapper):
return self.parenthesize_if_needed(result, enclosing_prec, PREC_SUM) return self.parenthesize_if_needed(result, enclosing_prec, PREC_SUM)
def map_product(self, expr, enclosing_prec): def map_product(self, expr, enclosing_prec, type_context):
def base_impl(expr, enclosing_prec, type_context):
# Spaces prevent '**z' (times dereference z), which
# is hard to read.
return self.parenthesize_if_needed(
self.join_rec(" * ", expr.children, PREC_PRODUCT, type_context),
enclosing_prec, PREC_PRODUCT)
if not self.allow_complex: if not self.allow_complex:
return CCodeMapper.map_product(self, expr, enclosing_prec) return base_impl(expr, enclosing_prec, type_context)
tgt_dtype = self.infer_type(expr) tgt_dtype = self.infer_type(expr)
is_complex = 'c' == tgt_dtype.kind is_complex = 'c' == tgt_dtype.kind
if not is_complex: if not is_complex:
return CCodeMapper.map_product(self, expr, enclosing_prec) return base_impl(expr, enclosing_prec, type_context)
else: else:
tgt_name = self.complex_type_name(tgt_dtype) tgt_name = self.complex_type_name(tgt_dtype)
...@@ -393,19 +472,18 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -393,19 +472,18 @@ class LoopyCCodeMapper(CCodeMapper):
complexes = [child for child in expr.children complexes = [child for child in expr.children
if 'c' == self.infer_type(child).kind] if 'c' == self.infer_type(child).kind]
from pymbolic.mapper.stringifier import PREC_PRODUCT real_prd = self.join_rec("*", reals, PREC_PRODUCT, type_context)
real_prd = self.join_rec("*", reals, PREC_PRODUCT)
if len(complexes) == 1: if len(complexes) == 1:
myprec = PREC_PRODUCT myprec = PREC_PRODUCT
else: else:
myprec = PREC_NONE myprec = PREC_NONE
complex_prd = self.rec(complexes[0], myprec) complex_prd = self.rec(complexes[0], myprec, type_context)
for child in complexes[1:]: for child in complexes[1:]:
complex_prd = "%s_mul(%s, %s)" % ( complex_prd = "%s_mul(%s, %s)" % (
tgt_name, complex_prd, tgt_name, complex_prd,
self.rec(child, PREC_NONE)) self.rec(child, PREC_NONE, type_context))
if real_prd: if real_prd:
# elementwise semantics are correct # elementwise semantics are correct
...@@ -415,9 +493,19 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -415,9 +493,19 @@ class LoopyCCodeMapper(CCodeMapper):
return self.parenthesize_if_needed(result, enclosing_prec, PREC_PRODUCT) return self.parenthesize_if_needed(result, enclosing_prec, PREC_PRODUCT)
def map_quotient(self, expr, enclosing_prec): def map_quotient(self, expr, enclosing_prec, type_context):
def base_impl(expr, enclosing_prec, type_context):
return self.parenthesize_if_needed(
"%s / %s" % (
# space is necessary--otherwise '/*' becomes
# start-of-comment in C.
self.rec(expr.numerator, PREC_PRODUCT, type_context),
# analogous to ^{-1}
self.rec(expr.denominator, PREC_POWER, type_context)),
enclosing_prec, PREC_PRODUCT)
if not self.allow_complex: if not self.allow_complex:
return CCodeMapper.map_quotient(self, expr, enclosing_prec) return base_impl(expr, enclosing_prec, type_context)
n_complex = 'c' == self.infer_type(expr.numerator).kind n_complex = 'c' == self.infer_type(expr.numerator).kind
d_complex = 'c' == self.infer_type(expr.denominator).kind d_complex = 'c' == self.infer_type(expr.denominator).kind
...@@ -425,36 +513,48 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -425,36 +513,48 @@ class LoopyCCodeMapper(CCodeMapper):
tgt_dtype = self.infer_type(expr) tgt_dtype = self.infer_type(expr)
if not (n_complex or d_complex): if not (n_complex or d_complex):
return CCodeMapper.map_quotient(self, expr, enclosing_prec) return base_impl(expr, enclosing_prec, type_context)
elif n_complex and not d_complex: elif n_complex and not d_complex:
# elementwise semantics are correct # elementwise semantics are correct
return CCodeMapper.map_quotient(self, expr, enclosing_prec) return base_impl(expr, enclosing_prec, type_context)
elif not n_complex and d_complex: elif not n_complex and d_complex:
return "%s_rdivide(%s, %s)" % ( return "%s_rdivide(%s, %s)" % (
self.complex_type_name(tgt_dtype), self.complex_type_name(tgt_dtype),
self.rec(expr.numerator, PREC_NONE), self.rec(expr.numerator, PREC_NONE, type_context),
self.rec(expr.denominator, PREC_NONE)) self.rec(expr.denominator, PREC_NONE, type_context))
else: else:
return "%s_divide(%s, %s)" % ( return "%s_divide(%s, %s)" % (
self.complex_type_name(tgt_dtype), self.complex_type_name(tgt_dtype),
self.rec(expr.numerator, PREC_NONE), self.rec(expr.numerator, PREC_NONE, type_context),
self.rec(expr.denominator, PREC_NONE)) self.rec(expr.denominator, PREC_NONE, type_context))
def map_remainder(self, expr, enclosing_prec):
if not self.allow_complex:
return CCodeMapper.map_remainder(self, expr, enclosing_prec)
def map_remainder(self, expr, enclosing_prec, type_context):
tgt_dtype = self.infer_type(expr) tgt_dtype = self.infer_type(expr)
if 'c' == tgt_dtype.kind: if 'c' == tgt_dtype.kind:
raise RuntimeError("complex remainder not defined") raise RuntimeError("complex remainder not defined")
return CCodeMapper.map_remainder(self, expr, enclosing_prec) return "(%s %% %s)" % (
self.rec(expr.numerator, PREC_PRODUCT, type_context),
self.rec(expr.denominator, PREC_POWER, type_context)) # analogous to ^{-1}
def map_power(self, expr, enclosing_prec, type_context):
def base_impl(expr, enclosing_prec, type_context):
from pymbolic.mapper.stringifier import PREC_NONE
from pymbolic.primitives import is_constant, is_zero
if is_constant(expr.exponent):
if is_zero(expr.exponent):
return "1"
elif is_zero(expr.exponent - 1):
return self.rec(expr.base, enclosing_prec, type_context)
elif is_zero(expr.exponent - 2):
return self.rec(expr.base*expr.base, enclosing_prec, type_context)
return "pow(%s, %s)" % (
self.rec(expr.base, PREC_NONE, type_context),
self.rec(expr.exponent, PREC_NONE, type_context))
def map_power(self, expr, enclosing_prec):
if not self.allow_complex: if not self.allow_complex:
return CCodeMapper.map_power(self, expr, enclosing_prec) return base_impl(expr, enclosing_prec, type_context)
from pymbolic.mapper.stringifier import PREC_NONE
tgt_dtype = self.infer_type(expr) tgt_dtype = self.infer_type(expr)
if 'c' == tgt_dtype.kind: if 'c' == tgt_dtype.kind:
...@@ -462,7 +562,7 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -462,7 +562,7 @@ class LoopyCCodeMapper(CCodeMapper):
value = expr.base value = expr.base
for i in range(expr.exponent-1): for i in range(expr.exponent-1):
value = value * expr.base value = value * expr.base
return self.rec(value, enclosing_prec) return self.rec(value, enclosing_prec, type_context)
else: else:
b_complex = 'c' == self.infer_type(expr.base).kind b_complex = 'c' == self.infer_type(expr.base).kind
e_complex = 'c' == self.infer_type(expr.exponent).kind e_complex = 'c' == self.infer_type(expr.exponent).kind
...@@ -470,18 +570,22 @@ class LoopyCCodeMapper(CCodeMapper): ...@@ -470,18 +570,22 @@ class LoopyCCodeMapper(CCodeMapper):
if b_complex and not e_complex: if b_complex and not e_complex:
return "%s_powr(%s, %s)" % ( return "%s_powr(%s, %s)" % (
self.complex_type_name(tgt_dtype), self.complex_type_name(tgt_dtype),
self.rec(expr.base, PREC_NONE), self.rec(expr.base, PREC_NONE, type_context),
self.rec(expr.exponent, PREC_NONE)) self.rec(expr.exponent, PREC_NONE, type_context))
else: else:
return "%s_pow(%s, %s)" % ( return "%s_pow(%s, %s)" % (
self.complex_type_name(tgt_dtype), self.complex_type_name(tgt_dtype),
self.rec(expr.base, PREC_NONE), self.rec(expr.base, PREC_NONE, type_context),
self.rec(expr.exponent, PREC_NONE)) self.rec(expr.exponent, PREC_NONE, type_context))
return CCodeMapper.map_power(self, expr, enclosing_prec) return base_impl(self, expr, enclosing_prec, type_context)
# }}} # }}}
def __call__(self, expr, type_context, prec=PREC_NONE):
    """Generate code for *expr*.

    :arg expr: the expression to map.
    :arg type_context: ``'i'``, ``'f'``, ``'d'`` or *None*; forwarded to
        the mapper methods as their last argument.
    :arg prec: enclosing precedence; *None* is treated as ``PREC_NONE``.
    :returns: whatever ``RecursiveMapper.__call__`` returns for *expr*.
    """
    from pymbolic.mapper import RecursiveMapper

    # Some callers pass prec=None explicitly; normalise it so precedence
    # comparisons further down always see a real precedence value.
    if prec is None:
        prec = PREC_NONE

    # Note the argument order: the mapper methods take (expr, prec, type_context).
    return RecursiveMapper.__call__(self, expr, prec, type_context)
# }}} # }}}
# vim: fdm=marker # vim: fdm=marker
...@@ -12,11 +12,18 @@ def generate_instruction_code(kernel, insn, codegen_state): ...@@ -12,11 +12,18 @@ def generate_instruction_code(kernel, insn, codegen_state):
expr = insn.expression expr = insn.expression
from loopy.codegen.expression import perform_cast from loopy.codegen.expression import perform_cast
expr = perform_cast(ccm, expr, expr_dtype=ccm.infer_type(expr), target_dtype = kernel.get_var_descriptor(insn.get_assignee_var_name()).dtype
target_dtype=kernel.get_var_descriptor(insn.get_assignee_var_name()).dtype) expr_dtype = ccm.infer_type(expr)
expr = perform_cast(ccm, expr,
expr_dtype=expr_dtype,
target_dtype=target_dtype)
from cgen import Assign from cgen import Assign
insn_code = Assign(ccm(insn.assignee), ccm(expr)) from loopy.codegen.expression import dtype_to_type_context
insn_code = Assign(
ccm(insn.assignee, prec=None, type_context=None),
ccm(expr, prec=None, type_context=dtype_to_type_context(target_dtype)))
from loopy.codegen.bounds import wrap_in_bounds_checks from loopy.codegen.bounds import wrap_in_bounds_checks
insn_inames = kernel.insn_inames(insn) insn_inames = kernel.insn_inames(insn)
insn_code, impl_domain = wrap_in_bounds_checks( insn_code, impl_domain = wrap_in_bounds_checks(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment