diff --git a/pymbolic/mapper/c_code.py b/pymbolic/mapper/c_code.py index 7d99f939bc28ace0164878a45d3fb748ca3adac2..4a1b9b7ec22c764f045add605344f9f30fd2a99a 100644 --- a/pymbolic/mapper/c_code.py +++ b/pymbolic/mapper/c_code.py @@ -5,19 +5,29 @@ from pymbolic.mapper.stringifier import SimplifyingSortingStringifyMapper class CCodeMapper(SimplifyingSortingStringifyMapper): def __init__(self, constant_mapper=repr, reverse=True, - cse_prefix="_cse", complex_constant_base_type="double"): + cse_prefix="_cse", complex_constant_base_type="double", + cse_name_list=[]): SimplifyingSortingStringifyMapper.__init__(self, constant_mapper, reverse) self.cse_prefix = cse_prefix - self.cse_to_name = {} - self.cse_names = set() - self.cse_name_list = [] + self.cse_to_name = dict((cse, name) for name, cse in cse_name_list) + self.cse_names = set(cse for name, cse in cse_name_list) + self.cse_name_list = cse_name_list[:] self.complex_constant_base_type = complex_constant_base_type + def copy(self, cse_name_list=None): + if cse_name_list is None: + cse_name_list = self.cse_name_list + return CCodeMapper(self.constant_mapper, self.reverse, + self.cse_prefix, self.complex_constant_base_type, + cse_name_list) + + def copy_with_mapped_cses(self, cses_and_values): + return self.copy(self.cse_name_list + cses_and_values) + # mappings ---------------------------------------------------------------- def map_constant(self, x, enclosing_prec): - import numpy if isinstance(x, complex): return "std::complex<%s>(%s, %s)" % ( self.complex_constant_base_type,