diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index 9ba97a48880294e07e77d89eae8b740056a0c1f4..ea0cd495837454673ddcf3f3c6dd260bd61f97c0 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -9,8 +9,12 @@ from pymbolic.mapper import CombineMapper # {{{ type inference class TypeInferenceMapper(CombineMapper): - def __init__(self, kernel): + def __init__(self, kernel, temporary_variables=None): self.kernel = kernel + if temporary_variables is None: + temporary_variables = kernel.temporary_variables + + self.temporary_variables = temporary_variables def combine(self, dtypes): dtypes = list(dtypes) @@ -56,7 +60,7 @@ class TypeInferenceMapper(CombineMapper): pass try: - return self.kernel.temporary_variables[expr.name].dtype + return self.temporary_variables[expr.name].dtype except KeyError: pass @@ -72,15 +76,8 @@ class TypeInferenceMapper(CombineMapper): class LoopyCCodeMapper(CCodeMapper): def __init__(self, kernel, cse_name_list=[], var_subst_map={}, with_annotation=False, allow_complex=False): - def constant_mapper(c): - if isinstance(c, float): - # FIXME: type-variable - return "%sf" % repr(c) - else: - return repr(c) - CCodeMapper.__init__(self, constant_mapper=constant_mapper, - cse_name_list=cse_name_list) + CCodeMapper.__init__(self, cse_name_list=cse_name_list) self.kernel = kernel self.infer_type = TypeInferenceMapper(kernel) self.allow_complex = allow_complex @@ -217,6 +214,16 @@ class LoopyCCodeMapper(CCodeMapper): map_max = map_min + def map_constant(self, expr, enclosing_prec): + if isinstance(expr, complex): + # FIXME: type-variable + return "(cdouble_t) (%s, %s)" % (repr(expr.real), repr(expr.imag)) + elif isinstance(expr, float): + # FIXME: type-variable + return "%s" % repr(expr) + else: + return CCodeMapper.map_constant(self, expr, enclosing_prec) + # {{{ deal with complex-valued variables def complex_type_name(self, dtype):