diff --git a/dagrt/codegen/cxx.py b/dagrt/codegen/cxx.py new file mode 100644 index 0000000000000000000000000000000000000000..e968ccb6c156e3b8ffa39d10f7e632c87711485e --- /dev/null +++ b/dagrt/codegen/cxx.py @@ -0,0 +1,2937 @@ +"""CXX test code generator""" +from __future__ import division + +__copyright__ = "Copyright (C) 2014 Matt Wala, Andreas Kloeckner, Cory Mikida" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import sys +from functools import partial +import re # noqa +import six + +from dagrt.codegen.expressions import CXXExpressionMapper +from dagrt.codegen.codegen_base import StructuredCodeGenerator +from dagrt.utils import is_state_variable +from dagrt.data import UserType +from pytools.py_codegen import ( + # It's the same code. So sue me. + PythonCodeGenerator as CXXEmitterBase) +from pymbolic.primitives import (Call, CallWithKwargs, Variable, + Subscript, Lookup) +from pymbolic.mapper import IdentityMapper +from dagrt.codegen.utils import (wrap_line_base, KeyToUniqueNameMap, + make_identifier_from_name) + + +def pad_cxx(line, width): + line += ' ' * (width - 1 - len(line)) + return line + + +wrap_line = partial(wrap_line_base, pad_func=pad_cxx) + + +# {{{ name manager + +class CXXNameManager(object): + """Maps names that appear in intermediate code to CXX identifiers. + """ + + def __init__(self): + from pytools import UniqueNameGenerator + self.name_generator = UniqueNameGenerator() + self.local_map = KeyToUniqueNameMap(name_generator=self.name_generator) + self.global_map = KeyToUniqueNameMap(start={ + '': 'dagrt_t', '
': 'dagrt_dt'}, + name_generator=self.name_generator) + self.function_map = KeyToUniqueNameMap(name_generator=self.name_generator) + + def name_global(self, var): + """Return the identifier for a global variable.""" + return self.global_map.get_or_make_name_for_key(var) + + def name_local(self, var, prefix=None): + """Return the identifier for a local variable.""" + if prefix is None: + if not var.startswith("dagrt_"): + prefix = "lploc_" + + return self.local_map.get_or_make_name_for_key(var, prefix=prefix) + + def name_function(self, var): + """Return the identifier for a function.""" + return self.function_map.get_or_make_name_for_key(var) + + def make_unique_cxx_name(self, prefix): + return self.local_map.get_mapped_identifier_without_key("drtcxx_"+prefix) + + def is_known_cxx_name(self, name): + return self.name_generator.is_name_conflicting(name) + + def __getitem__(self, name): + """Provide an interface to the expression mapper to look up + the name of a local or global variable. + """ + if is_state_variable(name): + return 'dagrt_state.'+self.name_global(name) + else: + return self.name_local(name) + + +# }}} + + +# {{{ custom emitters + +class CXXEmitter(CXXEmitterBase): + def incorporate(self, sub_generator): + for line in sub_generator.code: + self(line) + + +class CXXBlockEmitter(CXXEmitter): + def __init__(self, what, code_generator=None): + super(CXXBlockEmitter, self).__init__() + self.what = what + + self.code_generator = code_generator + + def __enter__(self): + if self.code_generator is not None: + self.code_generator.emitters.append(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.code_generator is not None: + self.code_generator.emitters.pop() + + +class CXXHeaderEmitter(CXXBlockEmitter): + def __init__(self, header_name): + super(CXXHeaderEmitter, self).__init__('header') + self.header_name = header_name + self('// {header_name}'.format(header_name=header_name)) + self('#include ') + self('#include ') + self('#include ') + self('#include ') + self('#include ') + self('#include ') + self('#include ') + self('using namespace std;') + self('') + + +class CXXSubblockEmitter(CXXBlockEmitter): + def __init__(self, parent_emitter, what, code_generator=None): + super(CXXSubblockEmitter, self).__init__(what, code_generator) + self.parent_emitter = parent_emitter + + def __exit__(self, exc_type, exc_val, exc_tb): + super(CXXSubblockEmitter, self).__exit__( + exc_type, exc_val, exc_tb) + + self.dedent() + if isinstance(self, CXXStructEmitter): + self('};') + else: + self('}') + self('') + self.parent_emitter.incorporate(self) + + +class CXXIfEmitter(CXXSubblockEmitter): + def __init__(self, parent_emitter, expr, code_generator=None): + super(CXXIfEmitter, self).__init__( + parent_emitter, "if", code_generator) + self("if ({expr})".format(expr=expr)) + self('{') + self.indent() + + def emit_else(self): + self.dedent() + self('}') + self('else') + self('{') + self.indent() + + def emit_else_if(self, expr): + self("else if ({expr})".format(expr=expr)) + self('{') + self.indent() + + def emit_end(self): + self('') + + +class CXXElseIfEmitter(CXXSubblockEmitter): + def __init__(self, parent_emitter, expr, code_generator=None): + super(CXXElseIfEmitter, self).__init__( + parent_emitter, "if", code_generator) + self("else if ({expr})".format(expr=expr)) + self('{') + self.indent() + + +class CXXElseEmitter(CXXSubblockEmitter): + def __init__(self, parent_emitter, code_generator=None): + super(CXXElseEmitter, self).__init__( + parent_emitter, "if", code_generator) + self("else") + self('{') + self.indent() + + +class CXXForEmitter(CXXSubblockEmitter): + def __init__(self, parent_emitter, loop_var, lower_bound, + upper_bound, code_generator=None): + super(CXXForEmitter, self).__init__( + parent_emitter, "for", code_generator) + + self("for (int {loop_var}={low_bound}; " + "{loop_var} <= {high_bound}; {loop_var}++)".format( + loop_var=loop_var, low_bound=lower_bound, high_bound=upper_bound)) + self('{') + self.indent() + + +class CXXVoidFunctionEmitter(CXXSubblockEmitter): + def __init__(self, parent_emitter, name, arg_string, code_generator=None): + super(CXXVoidFunctionEmitter, self).__init__( + parent_emitter, 'void', code_generator) + self.name = name + + self('void %s(%s)' % (name, arg_string)) + self('{') + self.indent() + + +class CXXStructEmitter(CXXSubblockEmitter): + def __init__(self, parent_emitter, type_name, code_generator=None): + super(CXXStructEmitter, self).__init__( + parent_emitter, 'type', code_generator) + self('struct {type_name} {{'.format(type_name=type_name)) + self.indent() + +# }}} + + +# {{{ code generation for function calls + +class CallCode(object): + """Encapsulation for a CXX code template embodying a dagrt-level function call. + """ + + def __init__(self, template, extra_args=None): + """ + :arg extra_args: a dictionary of names that should be made available + in template expansion. + """ + + from mako.template import Template + + self.template = Template(template, strict_undefined=True) + + self.extra_args = extra_args + + def __call__(self, results, function, args, arg_kinds, + code_generator): + from dagrt.codegen.utils import ( + remove_common_indentation, + remove_redundant_blank_lines) + + def add_declaration(decl): + code_generator.declaration_emitter(decl) + + def declare_new(decl_without_name, prefix): + new_name = code_generator.name_manager.make_unique_cxx_name(prefix) + code_generator.declaration_emitter(decl_without_name + " " + + new_name + ";") + return new_name + + import dagrt.data as kinds + + template_names = dict( + real_scalar_kind=code_generator.real_scalar_kind, + complex_scalar_kind=code_generator.complex_scalar_kind, + get_new_identifier=( + code_generator.name_manager.make_unique_cxx_name), + add_declaration=add_declaration, + declare_new=declare_new, + kinds=kinds) + + result_names = getattr(function, "result_names", ("result",)) + + assert len(result_names) == len(results) + for res_name, res in zip(result_names, results): + template_names[res_name] = res + + template_names.update(zip(function.arg_names, args)) + + template_names.update( + (name+"_kind", kind) + for name, kind in zip(function.arg_names, arg_kinds)) + + if self.extra_args: + template_names.update(self.extra_args) + + rendered = self.template.render(**template_names) + + if sys.version_info < (3,): + rendered = rendered.encode() + + lines = remove_redundant_blank_lines( + remove_common_indentation(rendered)) + + for l in lines: + code_generator.emit(l) + +# }}} + + +# {{{ expression modifiers + +class UserTypeReferenceTransformer(IdentityMapper): + def __init__(self, code_generator): + self.code_generator = code_generator + + def find_sym_kind(self, expr): + if isinstance(expr, (Subscript, Lookup)): + return self.find_sym_kind(expr.aggregate) + elif isinstance(expr, Variable): + return self.code_generator.sym_kind_table.get( + self.code_generator.current_function, expr.name) + else: + raise TypeError("unsupported object") + + def transform(self, expr): + raise NotImplementedError + + def map_variable(self, expr): + if isinstance(self.find_sym_kind(expr), UserType): + return self.transform(expr) + else: + return expr + + map_lookup = map_variable + map_subscript = map_variable + + +class StructureLookupAppender(UserTypeReferenceTransformer): + def __init__(self, code_generator, component): + super(StructureLookupAppender, self).__init__(code_generator) + self.component = component + + def transform(self, expr): + return expr.attr(self.component) + + +class SharedPointerGetAppender(UserTypeReferenceTransformer): + def __init__(self, code_generator): + super(SharedPointerGetAppender, self).__init__(code_generator) + self.get = 'get()' + + def transform(self, expr): + return expr.attr(self.get) + + +class ArraySubscriptAppender(UserTypeReferenceTransformer): + def __init__(self, code_generator, subscript): + super(ArraySubscriptAppender, self).__init__(code_generator) + self.subscript = subscript + + def transform(self, expr): + return expr[self.subscript] + +# }}} + + +# {{{ cxx 'vector-ish' types + +class TypeBase(object): + """ + .. attribute:: base_type + .. attribute:: dimension + + A tuple of ``'200'``, ``'-5:5'``, or some such. + Entries may be numeric, too. + """ + + def get_base_type(self): + raise NotImplementedError() + + def get_type_specifiers(self, defer_dim): + raise NotImplementedError() + + def is_allocatable(self): + raise NotImplementedError() + + +class BuiltinType(TypeBase): + def __init__(self, type_name): + self.type_name = type_name + + def get_base_type(self): + return self.type_name + + def get_type_specifiers(self, defer_dim): + return () + + def is_allocatable(self): + return False + + +# Allocatable arrays are not yet supported, use pointers for now. +class ArrayType(TypeBase): + """ + .. attribute:: dimension + + A tuple of ``'200'``, ``'-5:5'``, or some such. + Entries may be numeric, too. Or they may refer + to variables that are available through + *extra_arguments* in :class:`CodeGenerator`. + + .. attribute:: index_vars + + If further dimensions within :attr:`element_type` depend + on indices into :attr:`dimension`, this tuple of variable + names determines what each index variable is called. + """ + + @staticmethod + def parse_dimension(dim): + parts = ":".split(dim) + + if len(parts) == 1: + return ("0", "%s - 1" % dim) + elif len(parts) == 2: + return tuple(parts) + else: + raise RuntimeError( + "unexpected number of parts in dimension spec '%s'" + % dim) + + INDEX_VAR_COUNTER = 0 + + def __init__(self, dimension, element_type, index_vars=None): + self.element_type = element_type + if isinstance(dimension, str): + dimension = tuple(d.strip() for d in dimension.split(",")) + self.dimension = tuple(str(i) for i in dimension) + + if isinstance(index_vars, str): + index_vars = tuple(iv.strip() for iv in index_vars.split(",")) + elif index_vars is None: + def get_index_var(): + ArrayType.INDEX_VAR_COUNTER += 1 + return "i%d" % ArrayType.INDEX_VAR_COUNTER + + index_vars = tuple(get_index_var() for d in dimension) + + if len(index_vars) != len(dimension): + raise ValueError("length of 'index_vars' does not match length " + "of 'dimension'") + + if not isinstance(element_type, TypeBase): + raise TypeError("element_type should be a subclass of TypeBase") + if isinstance(element_type, PointerType): + raise TypeError("Arrays of pointers are not allowed in CXX. " + "You must declare an intermediate StructureType instead.") + + self.index_vars = index_vars + + def get_base_type(self): + return self.element_type.get_base_type() + + def get_type_specifiers(self, defer_dim): + result = self.element_type.get_type_specifiers(defer_dim) + + return result + + def is_allocatable(self): + return self.element_type.is_allocatable() + + +class PointerType(TypeBase): + """ + .. attribute:: pointee_type + .. attribute:: is_synthetic + + A :class:`bool` flag indicating whether this pointer declaration + is genuinely part of the user type (False) or 'synthetically' inserted + by the code generator (True). + """ + + def __init__(self, pointee_type, is_synthetic=False): + self.pointee_type = pointee_type + self.is_synthetic = is_synthetic + + if not isinstance(pointee_type, TypeBase): + raise TypeError("pointee_type should be a subclass of TypeBase") + + def get_base_type(self): + return self.pointee_type.get_base_type() + + def get_type_specifiers(self, defer_dim): + if isinstance(self.pointee_type, ArrayType): + if isinstance(self.pointee_type.element_type, BuiltinType): + return ("std::shared_ptr<%s>" % self.pointee_type.get_base_type(),) + else: + #return ("std::shared_ptr<%s[]>" % + # self.pointee_type.get_base_type(),) + return ("std::shared_ptr<%s>" % self.pointee_type.get_base_type(),) + else: + return ("std::shared_ptr<%s>" % self.pointee_type.get_base_type(),) + + def is_allocatable(self): + return True + + +class StructureType(TypeBase): + """ + .. attribute:: members + + A tuple of **(name, type)** tuples. + """ + + def __init__(self, type_name, members): + self.type_name = type_name + self.members = members + + for i, (_, mtype) in enumerate(members): + if not isinstance(mtype, TypeBase): + raise TypeError("member with index %d has type that is " + "not a subclass of TypeBase" + % i) + + def get_base_type(self): + return "%s" % self.type_name + + def get_type_specifiers(self, defer_dim): + return () + + def is_allocatable(self): + return any( + member_type.is_allocatable() + for name, member_type in self.members) + +# }}} + + +# {{{ type visitor + +# {{{ helper functionality + +def _replace_indices(index_expr_map, s): + for name, expr in six.iteritems(index_expr_map): + s, _ = re.subn(r"\b" + name + r"\b", expr, s) + return s + + +class _ArrayLoopManager(object): + def __init__(self, array_type, code_generator): + self.array_type = array_type + self.code_generator = code_generator + + self.f_index_names = [ + code_generator.name_manager.make_unique_cxx_name(iname) + for iname in array_type.index_vars] + + self.f_dim_names = [ + code_generator.name_manager.make_unique_cxx_name(iname[-6:]) + for iname in array_type.dimension] + + def enter(self, index_expr_map, allow_parallel_do): + atype = self.array_type + + cg = self.code_generator + + self.emitters = [] + for iloop, (dim, index_name, dim_name) in enumerate( + reversed(list(zip(atype.dimension, self.f_index_names, + self.f_dim_names)))): + cg.declaration_emitter('int %s;' % index_name) + cg.declaration_emitter('int %s;' % dim_name) + + start, stop = atype.parse_dimension(dim) + + self.code_generator.emit( + "{name} = {expr};" + .format( + name=dim_name, + expr=_replace_indices(index_expr_map, stop))) + + em = CXXForEmitter( + cg.emitter, index_name, + _replace_indices(index_expr_map, start), + _replace_indices(index_expr_map, dim_name), + cg) + self.emitters.append(em) + em.__enter__() + + def update_index_expr_map(self, index_expr_map): + index_expr_map.update(zip(self.array_type.index_vars, self.f_index_names)) + + def get_loop_subscript(self): + from pymbolic import var + return tuple(var(""+v) for v in self.f_index_names) + + def leave(self): + while self.emitters: + em = self.emitters.pop() + em.__exit__(None, None, None) + +# }}} + + +class TypeVisitor(object): + recurse_only_if_allocatable = False + + def rec(self, cxx_type, *args, **kwargs): + return getattr(self, "visit_"+type(cxx_type).__name__)( + cxx_type, *args, **kwargs) + + __call__ = rec + + +class DeclarationGenerator(TypeVisitor): + def __init__(self, use_deferred_shape): + self.use_deferred_shape = use_deferred_shape + + def visit_BuiltinType(self, cxx_type): + return (cxx_type.type_name) + + def visit_ArrayType(self, cxx_type): + return self.rec(cxx_type.element_type) + + def visit_PointerType(self, cxx_type): + return self.rec(cxx_type.pointee_type) + + def visit_StructureType(self, cxx_type): + return ("%s" % cxx_type.type_name) + + +class CodeGeneratingTypeVisitor(TypeVisitor): + def __init__(self, code_generator): + self.code_generator = code_generator + + def visit_BuiltinType(self, cxx_type, cxx_expr_str, index_expr_map, + *args, **kwargs): + pass + + def visit_ArrayType(self, cxx_type, cxx_expr_str, index_expr_map, + *args, **kwargs): + if (self.recurse_only_if_allocatable + and not cxx_type.element_type.is_allocatable()): + return + + alm = _ArrayLoopManager(cxx_type, self.code_generator) + alm.enter(index_expr_map, allow_parallel_do=False) + + index_expr_map = index_expr_map.copy() + alm.update_index_expr_map(index_expr_map) + + self.rec(cxx_type.element_type, + "%s[%s]" % (cxx_expr_str, ", ".join(alm.f_index_names)), + index_expr_map, *args, **kwargs) + + alm.leave() + + def visit_PointerType(self, cxx_type, cxx_expr_str, index_expr_map, + *args, **kwargs): + if (self.recurse_only_if_allocatable + and not cxx_type.pointee_type.is_allocatable()): + return + + self.rec(cxx_type.pointee_type, cxx_expr_str, index_expr_map, + *args, **kwargs) + + def visit_StructureType(self, cxx_type, cxx_expr_str, index_expr_map, + *args, **kwargs): + for member_name, member_type in cxx_type.members: + if (self.recurse_only_if_allocatable + and not member_type.is_allocatable()): + continue + + self.rec(member_type, + cxx_expr_str+"."+member_name, + index_expr_map, + *args, **kwargs) + + +class PointerAliasCreatingArraySubscriptAppender(ArraySubscriptAppender): + """Used to hoist array base subexpressions out of assignments.""" + + def __init__(self, code_generator, subscript, cxx_type): + super(PointerAliasCreatingArraySubscriptAppender, self).__init__( + code_generator, subscript) + self.expr_to_alias = {} + self.cxx_type = cxx_type + + def transform(self, expr): + try: + expr_cxx_name = self.expr_to_alias[expr] + except KeyError: + expr_cxx_name = ( + self.code_generator.name_manager.make_unique_cxx_name( + "hoisted")) + + dg = DeclarationGenerator(use_deferred_shape=True) + self.code_generator.declaration_emitter(dg(self.cxx_type) + + " *" + expr_cxx_name + ";") + + self.code_generator.emit( + "{name} = {expr}" + .format( + name=expr_cxx_name, + expr=self.code_generator.expr(expr))) + + self.expr_to_alias[expr] = expr_cxx_name + + from pymbolic import var + return var(""+expr_cxx_name)[self.subscript] + + +class AssignmentEmitter(CodeGeneratingTypeVisitor): + def visit_BuiltinType(self, cxx_type, cxx_expr_str, index_expr_map, + rhs_expr, is_rhs_target): + self.code_generator.emit( + "{name} = {expr};" + .format( + name=cxx_expr_str, + expr=self.code_generator.expr(rhs_expr))) + + def visit_ArrayType(self, cxx_type, cxx_expr_str, index_expr_map, + rhs_expr, is_rhs_target): + el_is_primitive = isinstance(cxx_type.element_type, BuiltinType) + + cg = self.code_generator + + alm = _ArrayLoopManager(cxx_type, cg) + + if el_is_primitive and is_rhs_target: + lhs_cxx_name = cg.name_manager.make_unique_cxx_name("hoisted") + + dg = DeclarationGenerator(use_deferred_shape=True) + #cg.declaration_emitter("std::shared_ptr<%s[]>" % (dg(cxx_type)) + cg.declaration_emitter("std::shared_ptr<%s>" % (dg(cxx_type)) + + lhs_cxx_name + ";") + + cg.emit( + "{name} = {expr};" + .format( + name=lhs_cxx_name, + expr=cxx_expr_str)) + + transformer = PointerAliasCreatingArraySubscriptAppender( + self.code_generator, alm.get_loop_subscript(), + cxx_type=cxx_type) + transformer2 = ArraySubscriptAppender( + self.code_generator, alm.get_loop_subscript()) + rhs_expr = transformer2(rhs_expr) + + else: + lhs_cxx_name = cxx_expr_str + transformer = ArraySubscriptAppender( + self.code_generator, alm.get_loop_subscript()) + + spga = SharedPointerGetAppender(self.code_generator) + rhs_expr = transformer(spga(rhs_expr)) + + alm.enter(index_expr_map, + allow_parallel_do=el_is_primitive) + index_expr_map = index_expr_map.copy() + alm.update_index_expr_map(index_expr_map) + + self.rec(cxx_type.element_type, + "%s.get()[%s]" % (lhs_cxx_name, ", ".join(alm.f_index_names)), + index_expr_map, rhs_expr, is_rhs_target=is_rhs_target) + + alm.leave() + + def visit_PointerType(self, cxx_type, cxx_expr_str, index_expr_map, + rhs_expr, is_rhs_target): + if (self.recurse_only_if_allocatable + and not cxx_type.pointee_type.is_allocatable()): + return + + self.rec(cxx_type.pointee_type, cxx_expr_str, index_expr_map, + rhs_expr, is_rhs_target=True and not cxx_type.is_synthetic) + + def visit_StructureType(self, cxx_type, cxx_expr_str, index_expr_map, + rhs_expr, is_rhs_target): + for member_name, member_type in cxx_type.members: + sla = StructureLookupAppender(self.code_generator, member_name) + spga = SharedPointerGetAppender(self.code_generator) + + self.rec(member_type, + cxx_expr_str+"."+member_name, + index_expr_map, + spga(sla(rhs_expr)), + is_rhs_target=is_rhs_target) + + +class AllocationEmitter(CodeGeneratingTypeVisitor): + recurse_only_if_allocatable = True + + def visit_PointerType(self, cxx_type, cxx_expr_str, index_expr_map): + pointee_type = cxx_type.pointee_type + code_generator = self.code_generator + + dimension = "" + if isinstance(pointee_type, ArrayType): + if pointee_type.dimension: + dimension = '[%s]' % ', '.join( + str(dim_axis) for dim_axis in pointee_type.dimension) + code_generator.emit_traceable( + '{name}.reset(new {base_type}{dimension});'.format( + name=cxx_expr_str, + base_type=pointee_type.get_base_type(), + dimension=_replace_indices(index_expr_map, dimension))) + if pointee_type.is_allocatable(): + cxx_expr_str += ".get()" + self.rec(pointee_type, cxx_expr_str, index_expr_map) + + +class DeallocationEmitter(CodeGeneratingTypeVisitor): + recurse_only_if_allocatable = True + + def __init__(self, code_generator, deinitializer): + super(DeallocationEmitter, self).__init__(code_generator) + self.deinitializer = deinitializer + + def visit_PointerType(self, cxx_type, cxx_expr_str, index_expr_map): + pointee_type = cxx_type.pointee_type + code_generator = self.code_generator + + if pointee_type.is_allocatable(): + if isinstance(pointee_type, ArrayType): + cxx_rec_str = cxx_expr_str + ".get()" + else: + cxx_rec_str = cxx_expr_str + self.rec(pointee_type, cxx_rec_str, index_expr_map) + + self.deinitializer(pointee_type, cxx_expr_str + ".get()", index_expr_map) + + code_generator.emit_traceable('{id}.reset();'.format(id=cxx_expr_str)) + code_generator.emit_traceable("%s = nullptr;" % cxx_expr_str) + + +class InitializationEmitter(CodeGeneratingTypeVisitor): + recurse_only_if_allocatable = True + + def visit_PointerType(self, cxx_type, cxx_expr_str, index_expr_map): + self.code_generator.emit_traceable("%s = nullptr;" % cxx_expr_str) + +# }}} + + +# {{{ code generator + +class CodeGenerator(StructuredCodeGenerator): + """ + Generates a CXX header of name *header_name*, which defines a type + *dagrt_state_type* to hold the current state of the time integrator + along with several functions:: + + initialize(EXTRA_ARGUMENTS, dagrt_state, ...) + run(EXTRA_ARGUMENTS, dagrt_state) + shutdown(EXTRA_ARGUMENTS, dagrt_state) + print_profile(dagrt_state) + + *dagrt_state* is of type *dagrt_state_type`, and *EXTRA_ARGUMENTS* above matches + *extra_arguments* as passed to the constructor. + + The ``...`` arguments to ``initialize`` are optional and must be passed by + keyword. The following keywords arguments are available: + + * *dagrt_dt*: The initial time step size + * *dagrt_t*: The initial time + * *state_STATE*: The initial value for the :mod:`dagrt` variable ``STATE`` + + .. rubric:: Profiling information + + The following attributes are available and allowed for read access in + *dagrt_state_type* while outside of *run*: + + * *dagrt_state_PHASE_count* + * *dagrt_state_PHASE_failures* + * *dagrt_state_PHASE_time* + + * *dagrt_func_FUNC_count* + * *dagrt_func_FUNC_time* + + In all of the above, upper case denotes a "metavariable"--e.g. *PHASE* is + the name of a phase, or *FUNC* is the name of a function. The name of a + function will typically be ``something``, for which *FUNC* will be + ``func_something``. As a result, the profile field counting the number of + invocations of the function ``something`` will be named + *dagrt_func_func_something*. + """ + + language = "cxx" + + # {{{ constructor + + def __init__(self, header_name, + user_type_map, + function_registry=None, + header_preamble=None, + real_scalar_kind="8", + complex_scalar_kind="8", + use_complex_scalars=True, + call_before_state_update=None, + call_after_state_update=None, + extra_arguments=(), + extra_argument_types=(), + emit_instrumentation=False, + timing_function=None, + + trace=False): + """ + :arg function_registry: An instance of + :class:`dagrt.function_registry.FunctionRegistry` + :arg module_preamble: A string to include at the beginning of the + emitted module + :arg user_type_map: a map from user type names + to :class:`CXXType` instances + :arg call_before_state_update: The name of a function that should + be called before each state update. The function must be known + to *function_registry*. + :arg call_after_state_update: The name of a function that should + be called after each state update. The function must be known + to *function_registry*. + :arg extra_arguments: A tuple of names of extra arguments that are + prepended to the call signature of each generated function + and are available to refer to in user-supplied function + implementations. + :arg emit_instrumentation: True or False, whether to emit performance + instrumentation. + :arg timing_function: *None* or the name of a function that returns + wall clock time as a number of seconds, as a ``real*8``. + Required if *emit_instrumentation* is set to *True*. + """ + if function_registry is None: + from dagrt.function_registry import base_function_registry + function_registry = base_function_registry + + for type_name, type_val in six.iteritems(user_type_map): + # Because if it's already a pointer, we have a hard time declaring + # the input type of our memory management routines. + + if isinstance(type_val, PointerType): + raise ValueError("type '%s': PointerType is not allowed as the " + "outermost type in user type mappings" + % type_name) + + self.header_name = header_name + self.function_registry = function_registry + self.user_type_map = user_type_map + + self.trace = trace + + from dagrt.codegen.utils import remove_common_indentation + self.header_preamble = remove_common_indentation(header_preamble) + + self.real_scalar_kind = real_scalar_kind + self.complex_scalar_kind = complex_scalar_kind + self.use_complex_scalars = use_complex_scalars + self.call_before_state_update = call_before_state_update + self.call_after_state_update = call_after_state_update + + if isinstance(extra_arguments, str): + extra_arguments = tuple(s.strip() for s in extra_arguments.split(",")) + if isinstance(extra_argument_types, str): + extra_argument_types = tuple( + s.strip() for s in extra_argument_types.split(",")) + + self.extra_arguments = extra_arguments + self.extra_argument_types = extra_argument_types + + self.emit_instrumentation = emit_instrumentation + self.timing_function = timing_function + if emit_instrumentation and timing_function is None: + raise ValueError("must supply timing_function if " + "emit_instrumentation is True") + + self.name_manager = CXXNameManager() + self.expr_mapper = CXXExpressionMapper( + self.name_manager) + + self.function_and_arg_kinds_to_cxx_name = {} + + # FIXME: Should make extra arguments known to + # name manager + + self.header_emitter = CXXHeaderEmitter(header_name) + self.header_emitter.__enter__() + + self.emitters = [self.header_emitter] + + self.current_function = None + self.used = False + + # }}} + + # {{{ utilities + + def get_called_function_names(self, dag): + from dagrt.codegen.analysis import collect_function_names_from_dag + result = collect_function_names_from_dag(dag, no_expressions=True) + + if self.call_before_state_update: + result.add(self.call_before_state_update) + + if self.call_after_state_update: + result.add(self.call_after_state_update) + + return sorted(result) + + def get_alloc_check_name(self, utype_id): + return "dagrt_alloc_check_"+make_identifier_from_name(utype_id) + + def get_var_deinit_name(self, utype_id): + return "dagrt_deinit_"+make_identifier_from_name(utype_id) + + @staticmethod + def phase_name_to_phase_sym(phase_name): + return "dagrt_phase_"+phase_name + + @staticmethod + def component_name_to_component_sym(comp_name): + return "dagrt_component_"+comp_name + + @property + def emitter(self): + return self.emitters[-1] + + def expr(self, expr): + return self.expr_mapper(expr) + + def emit(self, line): + self.emitter(line) + + def emit_traceable(self, line): + self.emit_trace(line) + self.emit(line) + + def emit_trace(self, line): + if self.trace: + self.emit("fprintf(stdout, \"%s\");" % line) + + # }}} + + # {{{ main entrypoint + + def __call__(self, dag): + if self.used: + raise RuntimeError("cxx code generator may not be " + "used more than once") + self.used = True + + from dagrt.codegen.analysis import verify_code + verify_code(dag) + + from dagrt.codegen.transform import ( + eliminate_self_dependencies, + isolate_function_arguments, + isolate_function_calls, + expand_IfThenElse) + dag = eliminate_self_dependencies(dag) + dag = isolate_function_arguments(dag) + dag = isolate_function_calls(dag) + dag = expand_IfThenElse(dag) + + # from dagrt.language import show_dependency_graph + # show_dependency_graph(dag) + + # {{{ produce function name / function AST pairs + + from dagrt.codegen.dag_ast import ( + create_ast_from_phase, get_statements_in_ast) + + from collections import namedtuple + NameASTPair = namedtuple("NameASTPair", "name, ast") # noqa + fdescrs = [] + + for phase_name in sorted(six.iterkeys(dag.phases)): + ast = create_ast_from_phase(dag, phase_name) + fdescrs.append(NameASTPair(phase_name, ast)) + + # }}} + + from dagrt.data import SymbolKindFinder + + self.sym_kind_table = SymbolKindFinder(self.function_registry)( + [fd.name for fd in fdescrs], + [get_statements_in_ast(fd.ast) for fd in fdescrs]) + + from dagrt.codegen.analysis import ( + collect_ode_component_names_from_dag, + var_to_last_dependent_statement_mapping) + + component_ids = collect_ode_component_names_from_dag(dag) + + self.last_used_stmt_table = var_to_last_dependent_statement_mapping( + [fd.name for fd in fdescrs], + [get_statements_in_ast(fd.ast) for fd in fdescrs]) + + if not component_ids <= set(self.user_type_map): + raise RuntimeError("User type missing from user type map: %r" + % (component_ids - set(self.user_type_map))) + + from dagrt.data import Scalar, UserType + for comp_id in component_ids: + self.sym_kind_table.set( + None, ""+comp_id, Scalar(is_real_valued=True)) + self.sym_kind_table.set( + None, ""+comp_id, Scalar(is_real_valued=True)) + self.sym_kind_table.set( + None, ""+comp_id, UserType(comp_id)) + + self.begin_emit(dag) + # Forward-declare the functions that dagrt phases will call + self.header_body_emitter = CXXEmitter() + self.emitters.append(self.header_body_emitter) + self.fwd_decl_emitter = CXXEmitter() + + for fdescr in sorted(fdescrs, key=lambda fdescr: fdescr.name): + self.lower_function(fdescr.name, fdescr.ast) + self.finish_emit(dag) + + # Incorporate forward declarations + self.emitters[-2].incorporate(self.fwd_decl_emitter) + full_emitter = self.emitters.pop() + self.emitter.incorporate(full_emitter) + + del self.fwd_decl_emitter + del self.sym_kind_table + + code = self.get_code() + + new_lines = [] + for ln in code.split("\n"): + if ln.lstrip().startswith("#"): + hashmark_pos = ln.find("#") + assert hashmark_pos >= 0 + ln = "#" + ln[:hashmark_pos] + ln[hashmark_pos+1:] + + new_lines.append(ln) + + return "\n".join(new_lines) + + # }}} + + # {{{ lower_function + + def lower_function(self, function_name, ast): + self.current_function = function_name + + arg_types = {} + arg_types['&dagrt_state'] = 'dagrt_state_type' + for i, arg in enumerate(self.extra_arguments): + arg_types[arg] = self.extra_argument_types[i] + + args = self.extra_arguments + ('&dagrt_state',) + + arg_string = '' + for arg in args: + arg_string += arg_types[arg] + " " + arg + ", " + + arg_string = arg_string[:-2] + + self.emit_def_begin( + 'dagrt_phase_func_' + function_name, + arg_string, + phase_id=function_name) + + # {{{ instrumentation + + if self.emit_instrumentation: + self.emit( + "dagrt_state.dagrt_phase_{phase}_count " + "= dagrt_state.dagrt_phase_{phase}_count + 1;" + .format(phase=function_name)) + + timer_start_var = self.name_manager.make_unique_cxx_name( + "timer_start") + self.declaration_emitter("double " + timer_start_var + ";") + + self.emit( + "{timer_start_var} = {timing_function}();" + .format( + timer_start_var=timer_start_var, + timing_function=self.timing_function)) + + # }}} + + self.lower_ast(ast) + + self.emit('label999:; // exit label') + + # {{{ emit variable deinit + + sym_table = self.sym_kind_table.per_phase_table.get( + self.current_function, {}) + + for identifier, sym_kind in sorted(six.iteritems(sym_table)): + if (identifier, self.current_function) not in self.last_used_stmt_table: + self.emit_variable_deinit(identifier, sym_kind) + + # }}} + + self.emit_trace('leave %s' % self.current_function) + + # {{{ instrumentation + + if self.emit_instrumentation: + self.emit( + "dagrt_state.dagrt_phase_{phase}_time " + "= dagrt_state.dagrt_phase_{phase}_time " + "+ ({timing_function}() - {timer_start_var});" + .format( + phase=function_name, + timing_function=self.timing_function, + timer_start_var=timer_start_var, + )) + + # }}} + + self.emit_def_end(function_name) + + self.current_function = None + + # }}} + + # {{{ get_code + + def get_code(self): + assert not self.header_emitter.preamble + + indent_spaces = 1 + indentation = indent_spaces*' ' + + wrapped_lines = [] + for l in self.header_emitter.code: + line_leading_spaces = (len(l) - len(l.lstrip(" "))) + level = line_leading_spaces // indent_spaces + line_ind = level*indentation + if l[line_leading_spaces:].startswith("//"): + wrapped_lines.append(l) + else: + for wrapped_line in wrap_line( + l[line_leading_spaces:], + level, indentation=indentation): + wrapped_lines.append(line_ind+wrapped_line) + + return "\n".join(wrapped_lines) + + # }}} + + # {{{ begin/finish_emit + + def begin_emit(self, dag): + if self.header_preamble: + for l in self.header_preamble: + self.emit(l) + self.emit('') + + from dagrt.codegen.analysis import collect_time_ids_from_dag + for i, time_id in enumerate(sorted(collect_time_ids_from_dag(dag))): + self.emit("const int dagrt_time_{time_id} = {i};".format( + time_id=time_id, i=i)) + self.emit('') + + # {{{ phase name constants + + for i, phase in enumerate(sorted(dag.phases)): + phase_sym_name = self.phase_name_to_phase_sym(phase) + self.emit("const int {phase_sym_name} = {i};".format( + phase_sym_name=phase_sym_name, i=i)) + + self.emit('') + + # }}} + + # {{{ component name constants + + from dagrt.codegen.analysis import collect_ode_component_names_from_dag + component_ids = collect_ode_component_names_from_dag(dag) + + for i, comp_id in enumerate(sorted(component_ids)): + comp_sym_name = self.component_name_to_component_sym(comp_id) + self.emit("const int {comp_sym_name} = {i};".format( + comp_sym_name=comp_sym_name, + i=i)) + + self.emit('') + + # }}} + + # {{{ state type + + with CXXStructEmitter( + self.emitter, + 'dagrt_state_type', + self,) as emit: + emit('int dagrt_next_phase;') + emit('') + + for identifier, sym_kind in sorted(six.iteritems( + self.sym_kind_table.global_table)): + self.emit_variable_decl( + self.name_manager.name_global(identifier), + sym_kind=sym_kind) + + # {{{ instrumentation + + if self.emit_instrumentation: + emit('') + emit('// {{{ instrumentation') + emit('') + + for phase_name in sorted(dag.phases): + emit('int dagrt_phase_%s_count;' % phase_name) + emit('int dagrt_phase_%s_failures;' % phase_name) + emit('double dagrt_phase_%s_time;' % phase_name) + + emit('') + + for func_name in self.get_called_function_names(dag): + func_id = make_identifier_from_name(func_name) + emit('int dagrt_func_%s_count;' % func_id) + emit('double dagrt_func_%s_time;' % func_id) + + emit('') + emit('// }}}') + emit('') + + # }}} + + # }}} + + # {{{ memory management routines + + from dagrt.data import collect_user_types + user_types = collect_user_types(self.sym_kind_table) + + # {{{ allocation checks + + for utype_id in sorted(user_types): + val_name = make_identifier_from_name(utype_id) + function_name = self.get_alloc_check_name(utype_id) + + arg_types = {} + for i, arg in enumerate(self.extra_arguments): + arg_types[arg] = self.extra_argument_types[i] + sym_kind = UserType(utype_id) + arg_types["&" + val_name] = self.get_arg_type_name("&" + val_name, + sym_kind, is_pointer=True) + + args = self.extra_arguments + ("&" + val_name,) + + arg_string = '' + for arg in args: + arg_string += arg_types[arg] + " " + arg + ", " + + arg_string = arg_string[:-2] + + self.emit_def_begin(function_name, + arg_string) + + with CXXIfEmitter( + self.emitter, + '%s == nullptr' % val_name, self): + + cxxtype = self.get_cxx_type_for_user_type(utype_id) + AllocationEmitter(self)(cxxtype, val_name, {}) + + with CXXElseEmitter( + self.emitter, self): + + # If the refcount is 1, then nobody else is referring to + # the memory, and we might as well repurpose/overwrite it, + # so there's nothing more to do in that case. + + with CXXIfEmitter( + self.emitter, "%s.use_count() != 1" % val_name, self): + + # We get here if the refcount is not 1 initially, which + # means it's not zero here--someone else is still + # referring to the data. Let them have it, we'll make + # a new array. + + self.emit('') + + AllocationEmitter(self)(cxxtype, val_name, {}) + + self.emit_def_end(function_name) + + # }}} + + # {{{ deinit + + for utype_id in sorted(user_types): + val_name = make_identifier_from_name(utype_id) + function_name = self.get_var_deinit_name(utype_id) + + arg_types = {} + for i, arg in enumerate(self.extra_arguments): + arg_types[arg] = self.extra_argument_types[i] + sym_kind = UserType(utype_id) + cxxtype = self.get_cxx_type_for_user_type(sym_kind.identifier) + + arg_types["&" + val_name] = self.get_arg_type_name("&" + val_name, + sym_kind, is_pointer=True) + + args = self.extra_arguments + ("&" + val_name,) + + arg_string = '' + for arg in args: + arg_string += arg_types[arg] + " " + arg + ", " + + arg_string = arg_string[:-2] + + self.emit_def_begin(function_name, + arg_string) + + sym_kind = UserType(utype_id) + + with CXXIfEmitter( + self.emitter, '%s != nullptr' % val_name, self): + with CXXIfEmitter( + self.emitter, "%s.use_count() == 1" % val_name, self): + cxxtype = self.get_cxx_type_for_user_type(sym_kind.identifier) + DeallocationEmitter(self, InitializationEmitter(self))( + cxxtype, val_name, {}) + + with CXXElseEmitter( + self.emitter, self): + InitializationEmitter(self)(cxxtype, val_name, {}) + + self.emit_def_end(function_name) + + # }}} + + # }}} + + def finish_emit(self, dag): + for (function_id, arg_kinds), cxx_name in six.iteritems( + self.function_and_arg_kinds_to_cxx_name): + self.emit_dagrt_function(cxx_name, function_id, arg_kinds) + + self.emit_initialize(dag) + self.emit_shutdown() + self.emit_run_step(dag) + + self.emit_print_profile(dag) + + self.header_emitter.__exit__(None, None, None) + + self.emit("// vim:foldmethod=marker") + + # }}} + + # {{{ data management + + def get_cxx_type_for_user_type(self, type_identifier, is_argument=False): + cxxtype = self.user_type_map[type_identifier] + if not is_argument: + cxxtype = PointerType(cxxtype, is_synthetic=True) + + return cxxtype + + def emit_variable_decl(self, cxx_name, sym_kind, + is_argument=False, other_specifiers=(), emit=None): + if emit is None: + emit = self.emit + + from dagrt.data import Boolean, Scalar, Array, Integer + + type_specifiers = other_specifiers + + from dagrt.data import UserType + if isinstance(sym_kind, Boolean): + type_name = 'bool' + + elif isinstance(sym_kind, Array): + if sym_kind.is_real_valued or not self.use_complex_scalars: + type_name = 'std::vector' + else: + type_name = 'double _Complex' + type_specifiers = type_specifiers + + elif isinstance(sym_kind, Scalar): + if sym_kind.is_real_valued or not self.use_complex_scalars: + type_name = 'double' + else: + type_name = 'double _Complex' + + elif isinstance(sym_kind, Integer): + type_name = 'int' + + elif isinstance(sym_kind, UserType): + cxxtype = self.get_cxx_type_for_user_type(sym_kind.identifier) + + type_name = cxxtype.get_base_type() + type_specifiers = ( + type_specifiers + + cxxtype.get_type_specifiers(defer_dim=is_argument)) + + else: + raise ValueError("unknown variable kind: %s" % type(sym_kind).__name__) + + if type_specifiers: + emit('{type_specifier_list} {id};'.format( + type_specifier_list=", ".join(type_specifiers), + id=cxx_name)) + else: + emit('{type_name} {id};'.format( + type_name=type_name, + id=cxx_name)) + + def get_arg_type_name(self, arg_name, arg_kind, + is_pointer=False, emit=None): + if emit is None: + emit = self.emit + + from dagrt.data import Boolean, Scalar, Array, Integer + + from dagrt.data import UserType + if isinstance(arg_kind, Boolean): + type_name = 'bool' + + elif isinstance(arg_kind, Array): + if arg_kind.is_real_valued or not self.use_complex_scalars: + type_name = 'std::vector' + else: + type_name = 'double _Complex' + + elif isinstance(arg_kind, Scalar): + if arg_kind.is_real_valued or not self.use_complex_scalars: + type_name = 'double' + else: + type_name = 'double _Complex' + + elif isinstance(arg_kind, Integer): + type_name = 'int' + + elif isinstance(arg_kind, UserType): + cxxtype = self.get_cxx_type_for_user_type(arg_kind.identifier, + is_argument=False) + + type_name = cxxtype.get_base_type() + + else: + raise ValueError("unknown variable kind: %s %s" % ( + type(arg_kind).__name__, arg_name)) + + if is_pointer: + cxxtype = self.get_cxx_type_for_user_type(arg_kind.identifier, + is_argument=False) + if isinstance(cxxtype.pointee_type, ArrayType): + if isinstance(cxxtype.pointee_type.element_type, BuiltinType): + return ("std::shared_ptr<%s>" % type_name) + else: + #return ("std::shared_ptr<%s[]>" % type_name) + return ("std::shared_ptr<%s>" % type_name) + else: + return ("std::shared_ptr<%s>" % type_name) + else: + return type_name + + def emit_variable_init(self, name, sym_kind): + from dagrt.data import UserType + if isinstance(sym_kind, UserType): + cxxtype = self.get_cxx_type_for_user_type(sym_kind.identifier) + InitializationEmitter(self)(cxxtype, self.name_manager[name], {}) + + def emit_variable_deinit(self, name, sym_kind): + cxx_name = self.name_manager[name] + + from dagrt.data import UserType + if not isinstance(sym_kind, UserType): + return + + self.emit( + "{var_deinit_name}({args});" + .format( + var_deinit_name=self.get_var_deinit_name( + sym_kind.identifier), + args=", ".join( + self.extra_arguments + + (cxx_name,)) + )) + + def emit_refcounted_allocation(self, sym, sym_kind): + cxx_name = self.name_manager[sym] + + from dagrt.data import UserType + if not isinstance(sym_kind, UserType): + return + + cxxtype = self.get_cxx_type_for_user_type(sym_kind.identifier) + AllocationEmitter(self)(cxxtype, cxx_name, {}) + + def emit_allocation_check(self, sym, sym_kind): + cxx_name = self.name_manager[sym] + + self.emit( + "{alloc_check_name}({args});" + .format( + alloc_check_name=self.get_alloc_check_name( + sym_kind.identifier), + args=", ".join( + self.extra_arguments + + (cxx_name,)) + )) + + def emit_user_type_move(self, assignee_sym, assignee_cxx_name, + sym_kind, expr): + self.emit_variable_deinit(assignee_sym, sym_kind) + + self.emit_traceable( + "{name} = {expr};" + .format( + name=assignee_cxx_name, + expr=self.name_manager[expr.name])) + self.emit('') + + def emit_assign_expr_inner(self, + assignee_cxx_name, assignee_subscript, expr, sym_kind): + if assignee_subscript: + subscript_str = "[%s]" % ( + ", ".join( + "%s" % self.expr(i) + for i in assignee_subscript)) + else: + subscript_str = "" + + if isinstance(expr, (Call, CallWithKwargs)): + # These are supposed to have been transformed to AssignFunctionCall. + raise RuntimeError("bare Call/CallWithKwargs encountered in " + "CXX code generator") + + else: + self.emit_trace("{assignee_cxx_name}{subscript_str} = {expr}..." + .format( + assignee_cxx_name=assignee_cxx_name, + subscript_str=subscript_str, + expr=str(expr)[:50])) + + from dagrt.data import UserType + if not isinstance(sym_kind, UserType): + self.emit( + "{name}{subscript_str} = {expr};" + .format( + name=assignee_cxx_name, + subscript_str=subscript_str, + expr=self.expr(expr))) + else: + cxxtype = self.get_cxx_type_for_user_type(sym_kind.identifier) + AssignmentEmitter(self)( + cxxtype, assignee_cxx_name, {}, expr, + is_rhs_target=True) + + self.emit('') + + # }}} + + # {{{ emit_initialize + + def emit_initialize(self, dag): + init_symbols_pre = sorted( + sym + for sym in self.sym_kind_table.global_table + if not sym.startswith(""): + sorter[sym] = 0 + elif sym.startswith(""): + sorter[sym] = 1 + elif sym.startswith("
"): + sorter[sym] = 2 + else: + sorter[sym] = 3 + + init_symbols = sorted(init_symbols_pre, key=sorter.get) + + args = self.extra_arguments + ('&dagrt_state',) + tuple( + self.name_manager.name_global(sym) + ' = nullptr' + for sym in init_symbols) + + arg_types = {} + for i, arg in enumerate(self.extra_arguments): + arg_types[arg] = self.extra_argument_types[i] + arg_types['&dagrt_state'] = 'dagrt_state_type' + + from dagrt.data import UserType + # Make all parameters except dagrt_state default parameters with a default + # value of nullptr... so they all must be pointers + for sym in init_symbols: + sym_kind = self.sym_kind_table.global_table[sym] + cxx_name = self.name_manager.name_global(sym) + ' = nullptr' + arg_type_name = self.get_arg_type_name(cxx_name, sym_kind) + if isinstance(sym_kind, UserType): + cxxtype = self.get_cxx_type_for_user_type(sym_kind.identifier, + is_argument=False) + if isinstance(cxxtype.pointee_type, ArrayType): + if isinstance(cxxtype.pointee_type.element_type, BuiltinType): + arg_types[cxx_name] = "std::shared_ptr<%s>" % arg_type_name + else: + # arg_types[cxx_name] = "std::shared_ptr<%s[]>" % / + # arg_type_name + arg_types[cxx_name] = "std::shared_ptr<%s>" % arg_type_name + else: + arg_types[cxx_name] = "std::shared_ptr<%s>" % arg_type_name + else: + arg_types[cxx_name] = "std::shared_ptr<%s>" % arg_type_name + + arg_string = '' + for arg in args: + arg_string += arg_types[arg] + " " + arg + ", " + + arg_string = arg_string[:-2] + + function_name = 'initialize' + phase_id = ""+function_name + + self.emit_def_begin(function_name, arg_string, phase_id=phase_id) + + for sym in init_symbols: + sym_kind = self.sym_kind_table.global_table[sym] + cxx_name = self.name_manager.name_global(sym) + self.sym_kind_table.set(phase_id, ""+cxx_name, sym_kind) + + self.current_function = phase_id + + arg_types = {} + for sym in init_symbols: + sym_kind = self.sym_kind_table.global_table[sym] + cxx_name = self.name_manager.name_global(sym) + + self.emit( + "dagrt_state.dagrt_next_phase = dagrt_phase_{0};" + .format(dag.initial_phase)) + + for sym, sym_kind in sorted(six.iteritems( + self.sym_kind_table.global_table)): + self.emit_variable_init(sym, sym_kind) + + # {{{ initialize scalar outputs to NaN + + self.emit('') + self.emit('// initialize scalar outputs to NaN') + self.emit('') + + for sym in sorted(self.sym_kind_table.global_table): + sym_kind = self.sym_kind_table.global_table[sym] + + tgt_cxx_name = self.name_manager[sym] + + # All our scalars are floating-point numbers for now, + # so initializing them all to NaN is fine. + + from dagrt.data import Scalar + if sym.startswith(""+cxx_name), + is_rhs_target=True) + + # {{{ instrumentation + + if self.emit_instrumentation: + self.emit('') + self.emit('// {{{ instrumentation') + self.emit('') + + for phase_name in sorted(dag.phases): + self.emit('dagrt_state.dagrt_phase_%s_count = 0;' % phase_name) + self.emit('dagrt_state.dagrt_phase_%s_failures = 0;' % phase_name) + self.emit('dagrt_state.dagrt_phase_%s_time = 0;' % phase_name) + + self.emit('') + + for func_name in self.get_called_function_names(dag): + func_id = make_identifier_from_name(func_name) + self.emit('dagrt_state.dagrt_func_%s_count = 0;' % func_id) + self.emit('dagrt_state.dagrt_func_%s_time = 0;' % func_id) + + self.emit('') + self.emit('// }}}') + self.emit('') + + # }}} + + self.emit_def_end(function_name) + + self.current_function = None + + # }}} + + # {{{ emit_shutdown + + def emit_shutdown(self): + args = self.extra_arguments + ('&dagrt_state',) + + arg_types = {} + for i, arg in enumerate(self.extra_arguments): + arg_types[arg] = self.extra_argument_types[i] + arg_types['&dagrt_state'] = 'dagrt_state_type' + + arg_string = '' + for arg in args: + arg_string += arg_types[arg] + " " + arg + ", " + + arg_string = arg_string[:-2] + + function_name = 'shutdown' + phase_id = ""+function_name + self.emit_def_begin(function_name, arg_string, phase_id=phase_id) + + self.current_function = phase_id + + from dagrt.data import UserType + + for sym, sym_kind in sorted(six.iteritems(self.sym_kind_table.global_table)): + self.emit_variable_deinit(sym, sym_kind) + + for sym, sym_kind in sorted(six.iteritems(self.sym_kind_table.global_table)): + if isinstance(sym_kind, UserType): + cxx_name = self.name_manager[sym] + with CXXIfEmitter( + self.emitter, + '{id} != nullptr'.format(id=cxx_name), self): + self.emit( + 'std::cerr << "leaked reference in " <<' + '{name} << std::endl;' + .format(name=cxx_name)) + + self.emit_def_end(function_name) + + self.current_function = None + + # }}} + + # {{{ emit_run_step + + def emit_run_step(self, dag): + args = self.extra_arguments + ('&dagrt_state',) + + arg_types = {} + for i, arg in enumerate(self.extra_arguments): + arg_types[arg] = self.extra_argument_types[i] + arg_types['&dagrt_state'] = 'dagrt_state_type' + + arg_string = '' + for arg in args: + arg_string += arg_types[arg] + " " + arg + ", " + + arg_string = arg_string[:-2] + + function_name = 'run' + phase_id = ""+function_name + self.emit_def_begin(function_name, arg_string, phase_id=phase_id) + + self.current_function = phase_id + + # Modify for proper state passing + args = self.extra_arguments + ('dagrt_state',) + + if_emit = None + for name, phase_descr in sorted(six.iteritems(dag.phases)): + phase_sym_name = self.phase_name_to_phase_sym(name) + cond = "dagrt_state.dagrt_next_phase == "+phase_sym_name + + if if_emit is None: + if_emit = CXXIfEmitter( + self.emitter, cond, self) + if_emit.__enter__() + else: + if_emit = CXXElseIfEmitter( + self.emitter, cond, self) + if_emit.__enter__() + + self.emit( + "dagrt_state.dagrt_next_phase = " + + self.phase_name_to_phase_sym(phase_descr.next_phase) + ";") + + self.emit( + "dagrt_phase_func_{phase_name}({args});".format( + phase_name=name, + args=", ".join(args))) + + if_emit.__exit__(None, None, None) + + if if_emit: + with CXXElseEmitter(self.emitter, self): + self.emit('std::cerr << "encountered invalid phase in run" << ' + "dagrt_state.dagrt_next_phase << std::endl;") + + self.emit_def_end(function_name) + + self.current_function = None + + # }}} + + # {{{ emit_print_profile + + def emit_print_profile(self, dag): + args = ('&dagrt_state',) + + arg_types = {} + for i, arg in enumerate(self.extra_arguments): + arg_types[arg] = self.extra_argument_types[i] + arg_types['&dagrt_state'] = 'dagrt_state_type' + + arg_string = '' + for arg in args: + arg_string += arg_types[arg] + " " + arg + ", " + + arg_string = arg_string[:-2] + + function_name = 'print_profile' + self.emit_def_begin(function_name, arg_string) + + if self.emit_instrumentation: + delim = "-" * 75 + self.emit("fprintf(stdout, \"%s\\n\");" % delim) + self.emit("fprintf(stdout, \"dagrt profile information\\n\");") + self.emit("fprintf(stdout, \"%s\\n\");" % delim) + + for phase_name in sorted(dag.phases): + self.emit( + "std::cout << \"phase {phase} count: \" << " + .format(phase=phase_name)) + self.emit( + "dagrt_state.dagrt_phase_{phase}_count << std::endl;" + .format(phase=phase_name)) + self.emit( + "std::cout << \"phase {phase} failures: \" << " + .format(phase=phase_name)) + self.emit( + "dagrt_state.dagrt_phase_{phase}_failures << std::endl;" + .format(phase=phase_name)) + with CXXIfEmitter( + self.emitter, + 'dagrt_state.dagrt_phase_{phase}_count > 0' + .format(phase=phase_name), + self): + self.emit("double mean_time_{phase} = " + "dagrt_state.dagrt_phase_{phase}_time / " + "dagrt_state.dagrt_phase_{phase}_count;" + .format(phase=phase_name)) + self.emit( + "std::cout << \"phase {phase} mean time: \" << " + .format(phase=phase_name)) + self.emit( + "mean_time_{phase} << std::endl;" + .format(phase=phase_name)) + self.emit( + "std::cout << \"phase {phase} total time: \" << " + .format(phase=phase_name)) + self.emit( + "dagrt_state.dagrt_phase_{phase}_time << std::endl;" + .format(phase=phase_name)) + + self.emit('') + self.emit("fprintf(stdout, \"%s\\n\");" % delim) + self.emit('') + + for func_name in self.get_called_function_names(dag): + func_id = make_identifier_from_name(func_name) + self.emit( + "std::cout << \"function {func_name} count: \" <<" + .format(func_name=func_name)) + self.emit( + "dagrt_state.dagrt_func_{func_id}_count << std::endl;" + .format(func_id=func_id)) + + with CXXIfEmitter( + self.emitter, + 'dagrt_state.dagrt_func_{func_id}_count > 0' + .format(func_id=func_id), + self): + self.emit( + "double mean_time_{func_id} = " + "dagrt_state.dagrt_func_{func_id}_time / " + "dagrt_state.dagrt_func_{func_id}_count;" + .format(func_id=func_id)) + self.emit( + "std::cout << \"function {func_name} mean time: \" << " + .format(func_name=func_name)) + self.emit( + "mean_time_{func_id} << std::endl;" + .format(func_id=func_id)) + + self.emit( + "std::cout << \"function {func_name} total time: \" << " + .format(func_name=func_name)) + self.emit( + "dagrt_state.dagrt_func_{func_id}_time << std::endl;" + .format(func_id=func_id)) + + self.emit("fprintf(stdout, \"%s\\n\");" % delim) + + self.emit_def_end(function_name) + + # }}} + + # {{{ emit_dagrt_function + + def emit_dagrt_function(self, cxx_name, function_id, arg_kinds): + function = self.function_registry[function_id] + + arg_kinds_dict = dict(zip(function.arg_names, arg_kinds)) + + result_kinds = function.get_result_kinds(arg_kinds_dict, check=True) + + result_names = [self.name_manager.make_unique_cxx_name("res%d" % (i + 1)) + for i in range(len(result_kinds))] + + arg_types = {} + for i, arg in enumerate(self.extra_arguments): + arg_types[arg] = self.extra_argument_types[i] + arg_types['&dagrt_state'] = 'dagrt_state_type' + + from dagrt.data import UserType, Array + + arg_names = [] + for name, arg_kind in zip(function.arg_names, arg_kinds): + if isinstance(arg_kind, UserType): + arg_names.append('*' + name) + arg_types['*' + name] = self.get_arg_type_name(name, arg_kind) + elif isinstance(arg_kind, Array): + arg_names.append('*' + name) + arg_types['*' + name] = self.get_arg_type_name(name, arg_kind) + else: + arg_names.append(name) + arg_types[name] = self.get_arg_type_name(name, arg_kind) + + for name, res_kind in zip(result_names, result_kinds): + arg_names.append('*' + name) + arg_types['*' + name] = self.get_arg_type_name(name, res_kind) + + args = ( + list(self.extra_arguments) + + ["&dagrt_state"] + + list(arg_names)) + + arg_string = '' + for arg in args: + arg_string += arg_types[arg] + " " + arg + ", " + + arg_string = arg_string[:-2] + + self.emit_def_begin(cxx_name, arg_string) + + for name, arg_kind in zip(function.arg_names, arg_kinds): + if arg_kind is None: + # We may encounter None as an arg_kind, for arguments of + # state update notification. + self.declaration_emitter("int "+name+";") + + self.emit("") + + # {{{ instrumentation + + if self.emit_instrumentation: + self.emit( + "dagrt_state.dagrt_func_{func}_count " + "= dagrt_state.dagrt_func_{func}_count + 1;" + .format(func=make_identifier_from_name(function_id))) + + timer_start_var = self.name_manager.make_unique_cxx_name( + "timer_start") + self.declaration_emitter("double " + timer_start_var + ";") + + self.emit( + "{timer_start_var} = {timing_function}();" + .format( + timer_start_var=timer_start_var, + timing_function=self.timing_function)) + + # }}} + + func_codegen = function.get_codegen(self.language) + + func_codegen( + results=result_names, + function=function, + args=function.arg_names, + arg_kinds=arg_kinds, + code_generator=self) + + # {{{ instrumentation + + if self.emit_instrumentation: + self.emit( + "dagrt_state.dagrt_func_{func}_time " + "= dagrt_state.dagrt_func_{func}_time " + "+ ({timing_function}() - {timer_start_var});" + .format( + func=make_identifier_from_name(function_id), + timing_function=self.timing_function, + timer_start_var=timer_start_var, + )) + + # }}} + + self.emit_def_end(cxx_name) + + self.current_function = None + + # }}} + + # {{{ called by superclass + + def emit_def_begin(self, function_name, argument_string, phase_id=None): + self.declaration_emitter = CXXEmitter() + + CXXVoidFunctionEmitter( + self.emitter, + function_name, + argument_string, + self).__enter__() + + body_emitter = CXXEmitter() + self.emitters.append(body_emitter) + + if phase_id is not None: + sym_table = self.sym_kind_table.per_phase_table.get(phase_id, {}) + for identifier, sym_kind in sorted(six.iteritems(sym_table)): + self.emit_variable_decl( + self.name_manager[identifier], sym_kind, is_argument=True) + + if sym_table: + self.emit('') + + self.emit_trace('================================================') + self.emit_trace('enter %s' % function_name) + + if phase_id is not None: + for identifier, sym_kind in sorted(six.iteritems(sym_table)): + self.emit_variable_init(identifier, sym_kind) + + if sym_table: + self.emit('') + + def emit_def_end(self, func_id): + self.emitters[-2].incorporate(self.declaration_emitter) + + body_emitter = self.emitters.pop() + self.emitter.incorporate(body_emitter) + + # body emitter + self.emitter.__exit__(None, None, None) + + del self.declaration_emitter + + def emit_if_begin(self, expr): + CXXIfEmitter( + self.emitter, + self.expr(expr), + self).__enter__() + + def emit_if_end(self,): + self.emitter.__exit__(None, None, None) + + def emit_else_begin(self): + self.emitter.emit_else() # pylint:disable=no-member + + def emit_assign_expr(self, assignee_sym, assignee_subscript, expr): + from dagrt.data import UserType, Array + + assignee_cxx_name = self.name_manager[assignee_sym] + + sym_kind = self.sym_kind_table.get( + self.current_function, assignee_sym) + + if assignee_subscript and not isinstance(sym_kind, Array): + raise TypeError("only arrays support subscripted assignment") + return + + if not isinstance(sym_kind, UserType): + self.emit_assign_expr_inner( + assignee_cxx_name, assignee_subscript, expr, sym_kind) + return + + if assignee_subscript: + raise ValueError("User types do not support subscripting") + + if isinstance(expr, Variable): + self.emit_user_type_move( + assignee_sym, assignee_cxx_name, sym_kind, expr) + return + + from pymbolic import var + from pymbolic.mapper.dependency import DependencyMapper + + # We can't tolerate reading a variable that we're just assigning, + # as we need to make new storage for the assignee before the + # read starts. + assert var(assignee_sym) not in DependencyMapper()(expr) + + self.emit_allocation_check(assignee_sym, sym_kind) + self.emit_assign_expr_inner( + assignee_cxx_name, assignee_subscript, expr, sym_kind) + + def lower_inst(self, inst): + """Emit the code for an statement.""" + + self.emit("// {{{ %s" % inst) + self.emit("") + super(CodeGenerator, self).lower_inst(inst) + self.emit("") + self.emit("// }}}") + self.emit("") + + # {{{ emit_inst_Assign + + def emit_inst_Assign(self, inst): + start_em = self.emitter + + for iloop, (ident, start, stop) in enumerate(inst.loops): + em = CXXForEmitter( + self.emitter, + self.name_manager[ident], + self.expr(start), self.expr(stop-1), + code_generator=self) + em.__enter__() + + self.emit_assign_expr( + inst.assignee, inst.assignee_subscript, inst.expression) + + for ident, start, stop in inst.loops[::-1]: + self.emitter.__exit__(None, None, None) + + self.emit_deinit_for_last_usage_of_vars(inst) + + assert start_em is self.emitter + + # }}} + + # {{{ emit_inst_AssignFunctionCall + + def emit_inst_AssignFunctionCall(self, inst): + self.emit_trace("func call {results} = {expr}..." + .format( + results=", ".join(inst.assignees), + expr=str(inst.as_expression())[:50])) + + arg_strs_dict = {} + arg_kinds_dict = {} + from pymbolic.mapper.dependency import DependencyMapper + from pymbolic import var + from dagrt.data import UserType, Array, Boolean + forward_declare = 0 + for i, arg in enumerate(inst.parameters): + arg_strs_dict[i] = self.expr(arg) + assert isinstance(arg, Variable) + + # FIXME: This can fail for args of state update notification, + # hence the try/catch. + try: + arg_kinds_dict[i] = self.sym_kind_table.get( + self.current_function, arg.name) + except KeyError: + arg_kinds_dict[i] = None + + for arg_name, arg in inst.kw_parameters.items(): + arg_strs_dict[arg_name] = self.expr(arg) + assert isinstance(arg, Variable) + + # FIXME: This can fail for args of state update notification, + # hence the try/catch. + try: + arg_kinds_dict[arg_name] = self.sym_kind_table.get( + self.current_function, arg.name) + except KeyError: + pass + + for assignee_sym in inst.assignees: + sym_kind = self.sym_kind_table.get( + self.current_function, assignee_sym) + if isinstance(sym_kind, UserType): + self.emit_allocation_check(assignee_sym, sym_kind) + + assert var(assignee_sym) not in DependencyMapper()( + inst.as_expression()) + + # Check assignees for pointers that need to be dereferenced. + assignee_cxx_names = [] + for sym in inst.assignees: + sym_kind = self.sym_kind_table.get(self.current_function, sym) + name = self.name_manager[sym] + if isinstance(sym_kind, UserType): + assignee_cxx_names.append(name + '.get()') + elif isinstance(sym_kind, Array): + assignee_cxx_names.append('&' + name) + elif isinstance(sym_kind, Boolean): + assignee_cxx_names.append('&' + name) + else: + assignee_cxx_names.append('&' + name) + + function = self.function_registry[inst.function_id] + + arg_kinds = function.resolve_args(arg_kinds_dict) + + key = (inst.function_id, arg_kinds) + + try: + cxx_func_name = self.function_and_arg_kinds_to_cxx_name[key] + except KeyError: + forward_declare = 1 + cxx_func_name = self.name_manager.make_unique_cxx_name( + inst.function_id) + self.function_and_arg_kinds_to_cxx_name[key] = cxx_func_name + + # Check arguments for pointers that need to be dereferenced. + for arg_name, arg in inst.kw_parameters.items(): + if isinstance(arg_kinds_dict[arg_name], UserType): + # Dereference user types before passing + arg_strs_dict[arg_name] = arg_strs_dict[arg_name] + '.get()' + elif isinstance(arg_kinds_dict[arg_name], Array): + arg_strs_dict[arg_name] = '&' + arg_strs_dict[arg_name] + + for i, arg in enumerate(inst.parameters): + if isinstance(arg_kinds_dict[i], UserType): + # Dereference user types before passing + arg_strs_dict[i] = arg_strs_dict[i] + '.get()' + elif isinstance(arg_kinds_dict[i], Array): + arg_strs_dict[i] = '&' + arg_strs_dict[i] + + self.emit("{cxx_func_name}({args});" + .format( + cxx_func_name=cxx_func_name, + args=", ".join( + list(self.extra_arguments) + + ["dagrt_state"] + + list(function.resolve_args(arg_strs_dict)) + + assignee_cxx_names + ))) + + # For forward declaration of these functions + if forward_declare: + result_kinds = function.get_result_kinds(arg_kinds_dict, check=True) + + result_names = [self.name_manager.make_unique_cxx_name("res%d" % (i + 1)) + for i in range(len(result_kinds))] + + arg_types = {} + arg_types['&dagrt_state'] = 'dagrt_state_type' + for i, arg in enumerate(self.extra_arguments): + arg_types[arg] = self.extra_argument_types[i] + + arg_names = [] + for name, arg_kind in zip(function.arg_names, arg_kinds): + if isinstance(arg_kind, UserType): + arg_names.append('*' + name) + arg_types['*' + name] = self.get_arg_type_name(name, arg_kind) + elif isinstance(arg_kind, Array): + arg_names.append('*' + name) + arg_types['*' + name] = self.get_arg_type_name(name, arg_kind) + else: + arg_names.append(name) + arg_types[name] = self.get_arg_type_name(name, arg_kind) + + for name, res_kind in zip(result_names, result_kinds): + arg_names.append('*' + name) + arg_types['*' + name] = self.get_arg_type_name(name, res_kind) + + args = ( + list(self.extra_arguments) + + ["&dagrt_state"] + + list(arg_names)) + + arg_string = '' + for arg in args: + arg_string += arg_types[arg] + " " + arg + ", " + + arg_string = arg_string[:-2] + + self.fwd_decl_emitter("void %s(%s);" % + (cxx_func_name, arg_string)) + self.fwd_decl_emitter("") + + self.emit_deinit_for_last_usage_of_vars(inst) + + # }}} + + def emit_return(self): + self.emit("goto label999;") + + def emit_inst_YieldState(self, inst): + self.emit_assign_expr( + ''+inst.component_id, + (), + Variable("dagrt_time_"+str(inst.time_id))) + self.emit_assign_expr( + ''+inst.component_id, + (), + inst.time) + + from dagrt.language import AssignFunctionCall + from pymbolic import var + from dagrt.data import Integer + + if self.call_before_state_update: + self.sym_kind_table.set( + self.current_function, + self.component_name_to_component_sym(inst.component_id), + Integer()) + self.emit_inst_AssignFunctionCall( + AssignFunctionCall( + (), + self.call_before_state_update, + (var(self.component_name_to_component_sym( + inst.component_id)),))) + + self.emit_assign_expr( + ''+inst.component_id, + (), + inst.expression) + + if self.call_after_state_update: + self.sym_kind_table.set( + self.current_function, + self.component_name_to_component_sym(inst.component_id), + Integer()) + self.emit_inst_AssignFunctionCall( + AssignFunctionCall( + (), + self.call_after_state_update, + (var(self.component_name_to_component_sym( + inst.component_id)),))) + + def emit_deinit_for_last_usage_of_vars(self, inst): + """Check if, for any of the variables in instruction *inst*, + *inst* contains the last use of that variable in the + :attr:`current_function`. If so, emit code to deallocate that variable. + """ + from dagrt.utils import is_state_variable + + read_and_written = inst.get_read_variables() | inst.get_written_variables() + + for variable in read_and_written: + var_kind = self.sym_kind_table.get( + self.current_function, variable) + + # FIXME: This can fail for args of state update notification, + # hence the try/catch. + try: + last_used_stmt_id = self.last_used_stmt_table[ + variable, self.current_function] + except KeyError: + continue + if inst.id == last_used_stmt_id and not is_state_variable(variable): + self.emit_variable_deinit(variable, var_kind) + + def emit_inst_Raise(self, inst): + # FIXME: Reenable emitting full error message + # TBD: Quoting of quotes, extra-long lines + if inst.error_message: + self.emit("// " + inst.error_message) + + self.emit("std::cerr <<" + "\"{condition}\" << std::endl;".format( + condition=inst.error_condition.__name__)) + + def emit_inst_FailStep(self, inst): + if self.emit_instrumentation: + self.emit( + "dagrt_state.dagrt_phase_{phase}_failures " + "= dagrt_state.dagrt_phase_{phase}_failures + 1;" + .format(phase=self.current_function)) + + self.emit("goto label999;") + + def emit_inst_ExitStep(self, inst): + self.emit("goto label999;") + + def emit_inst_SwitchPhase(self, inst): + self.emit( + 'dagrt_state.dagrt_next_phase = ' + + self.phase_name_to_phase_sym(inst.next_phase) + ';') + + # }}} + +# }}} + + +# {{{ built-in functions + +class TypeVisitorWithResult(CodeGeneratingTypeVisitor): + def __init__(self, code_generator, result_expr): + super(TypeVisitorWithResult, self).__init__(code_generator) + self.result_expr = result_expr + + +class Norm2Computer(TypeVisitorWithResult): + def visit_BuiltinType(self, cxx_type, cxx_expr_str, index_expr_map): + self.code_generator.emit( + "*{result} = *{result} + pow({expr},2);" + .format( + result=self.result_expr, + expr=cxx_expr_str)) + + +def codegen_builtin_norm_2(results, function, args, arg_kinds, + code_generator): + result, = results + + from dagrt.data import Scalar, UserType, Array + x_kind = arg_kinds[0] + if isinstance(x_kind, Scalar): + if x_kind.is_real_valued: + ftype = BuiltinType("double") + else: + ftype = BuiltinType("double _Complex") + elif isinstance(x_kind, UserType): + ftype = code_generator.user_type_map[x_kind.identifier] + + elif isinstance(x_kind, Array): + code_generator.emit("*{result} = norm({arg});".format( + result=result, arg=args[0])) + return + + else: + raise TypeError("unsupported kind for norm_2 argument: %s" % x_kind) + + code_generator.emit("*{result} = 0;".format(result=result)) + code_generator.emit("") + + Norm2Computer(code_generator, result)(ftype, args[0], {}) + + code_generator.emit("") + code_generator.emit("*{result} = sqrt(*{result});".format(result=result)) + code_generator.emit("") + + +class LenComputer(TypeVisitorWithResult): + # FIXME: This could be made *way* more efficient by handling + # arrays of built-in types directly. + + def visit_BuiltinType(self, cxx_type, cxx_expr_str, index_expr_map): + self.code_generator.emit( + "*{result} = *{result} + 1;" + .format( + result=self.result_expr, + expr=cxx_expr_str)) + + +def codegen_builtin_len(results, function, args, arg_kinds, + code_generator): + result, = results + + from dagrt.data import Scalar, Array, UserType + x_kind = arg_kinds[0] + if isinstance(x_kind, Scalar): + if x_kind.is_real_valued: + ftype = BuiltinType("double") + else: + ftype = BuiltinType("double _Complex") + elif isinstance(x_kind, UserType): + ftype = code_generator.user_type_map[x_kind.identifier] + elif isinstance(x_kind, Array): + code_generator.emit("*{result} = ({arg}).size();".format( + result=result, + arg=args[0])) + return + else: + raise TypeError("unsupported kind for len argument: %s" % x_kind) + + code_generator.emit("*{result} = 0;".format(result=result)) + code_generator.emit("") + + LenComputer(code_generator, result)(ftype, args[0], {}) + code_generator.emit("") + + +class IsNaNComputer(TypeVisitorWithResult): + def visit_BuiltinType(self, cxx_type, cxx_expr_str, index_expr_map): + self.code_generator.emit( + "if (isnan({expr}) == true || isinf({expr}) == true) {{" + .format( + expr=cxx_expr_str)) + self.code_generator.emit( + " *{result} = true;" + .format( + result=self.result_expr)) + self.code_generator.emit("}") + + +def codegen_builtin_isnan(results, function, args, arg_kinds, + code_generator): + result, = results + + from dagrt.data import Scalar, UserType + x_kind = arg_kinds[0] + if isinstance(x_kind, Scalar): + if x_kind.is_real_valued: + ftype = BuiltinType("double") + else: + ftype = BuiltinType("double _Complex") + elif isinstance(x_kind, UserType): + ftype = code_generator.user_type_map[x_kind.identifier] + else: + raise TypeError("unsupported kind for isnan argument: %s" % x_kind) + + code_generator.emit("*{result} = false;".format(result=result)) + code_generator.emit("") + + IsNaNComputer(code_generator, result)(ftype, args[0], {}) + code_generator.emit("") + + +builtin_array = CallCode(""" + if (int(${n})!=${n}) + { + fprintf(stderr, "argument to array() is not an integer"); + } + + (*${result}).resize(int(${n})); + """) + + +UTIL_MACROS = """ + <%def name="write_matrix(mat_array, rows_var)" > + <% + i = declare_new("int", "i") + j = declare_new("int", "j") + a_cols = declare_new("int", "a_cols") + %> + + ${a_cols} = (*${a}).size() / int(${rows_var}); + + for (${i} = 0; ${i} < int(${rows_var})-1; ${i}++) + { + for (${j} = 0; ${j} < int(${a_cols})-1; ${j}++) + { + std::cout << ${mat_array}[${i}+${j}*${a_cols}] << std::endl; + } + } + + + <%def name="check_matrix(mat_array, cols_var, rows_var, func_name)" > + if (int(${cols_var}) != ${cols_var}) + { + std::cerr << "argument " << + "${cols_var}" << + " to " << "${func_name}" << + " is not an integer" << std::endl; + } + + ${rows_var} = (*${mat_array}).size() / int(${cols_var}); + + if (${rows_var} * int(${cols_var}) != (*${mat_array}).size()) + { + std::cerr << "size of argument " << + "${mat_array}" << + " to " << "${func_name}" << + " not divisible by " << + "${cols_var}" << std::endl; + } + + + <% + def get_lapack_letter(kind): + if kind.is_real_valued: + if real_scalar_kind == "4": + return "s" + elif real_scalar_kind == "8": + return "d" + else: + raise TypeError("unrecognized real kind %s" % real_scalar_kind) + else: + if complex_scalar_kind == "8": + return "c" + elif complex_scalar_kind == "16": + return "z" + else: + raise TypeError("unrecognized complex kind %s" + % complex_scalar_kind) + + def kind_to_cxx(kind): + if kind.is_real_valued: + if real_scalar_kind == "4": + return "std::vector" + elif real_scalar_kind == "8": + return "std::vector" + else: + return "double _Complex" + %> + + """ + + +builtin_matmul = CallCode(UTIL_MACROS + """ + <% + i = declare_new("int", "i") + j = declare_new("int", "j") + k = declare_new("int", "k") + a_rows = declare_new("int", "a_rows") + b_rows = declare_new("int", "b_rows") + res_size = declare_new("int", "res_size") + %> + + ${check_matrix(a, a_cols, a_rows, "matmul")} + ${check_matrix(b, b_cols, b_rows, "matmul")} + + ${a_rows} = (*${a}).size() / int(${a_cols}); + ${b_rows} = (*${b}).size() / int(${b_cols}); + + ${res_size} = ${a_rows} * int(${b_cols}); + + (*${result}).resize(${res_size}); + + for (int ${i} = 0; ${i} < ${b_cols}; ${i}++) + { + for (int ${j} = 0; ${j} < ${a_rows}; ${j}++) + { + /* Compute C(i,j) */ + double cij = (*${result})[int(${i}*${a_rows}+${j})]; + for( int ${k} = 0; ${k} < ${a_cols}; ${k}++ ) + { + cij += (*${a})[int(${j}+${a_rows}*${k})] + * (*${b})[int(${k}+${i}*${b_cols})]; + } + (*${result})[int(${i}*${a_rows}+${j})] = cij; + } + } + """) + + +builtin_transpose = CallCode(UTIL_MACROS + """ + <% + i = declare_new("int", "i") + j = declare_new("int", "j") + a_rows = declare_new("int", "a_rows") + res_size = declare_new("int", "res_size") + %> + + ${check_matrix(a, a_cols, a_rows, "transpose")} + + ${a_rows} = (*${a}).size() / int(${a_cols}); + ${res_size} = ${a_rows} * int(${a_cols}); + + (*${result}).resize(${res_size}); + + for(int ${i} = 0; ${i} < ${a_rows}; ${i}++) + { + for(int ${j} = 0; ${j} < ${a_cols}; ${j}++) + { + (*${result})[${a_cols} * ${i} + ${j}] = + (*${a})[${a_rows} * ${j} + ${i}]; + } + } + + """) + + +builtin_linear_solve = CallCode(UTIL_MACROS + """ + <% + res_size = declare_new("int", "res_size") + a_rows = declare_new("int", "a_rows") + b_rows = declare_new("int", "b_rows") + b_cols_int = declare_new("int", "b_cols_int") + + %> + + ${check_matrix(a, a_cols, a_rows, "linear_solve")} + ${check_matrix(b, b_cols, b_rows, "linear_solve")} + + if (int(${a_rows})!=int(${b_rows})) + { + std::cerr << "inconsistent matrix sizes in linear_solve" << std::endl; + } + if (int(${a_rows})!=int(${a_cols})) + { + std::cerr << "non-square matrix sizes in linear_solve" << std::endl; + } + + ${res_size} = int(${b_rows}) * int(${b_cols}); + + <% + if a_kind != b_kind: + raise TypeError("linear_solve requires both arguments " + "to have same kind") + + ltr = get_lapack_letter(a_kind) + + lu_temp = declare_new( + kind_to_cxx(a_kind) + , "lu_temp") + ipiv = declare_new("int*", "ipiv") + info = declare_new("int", "info") + %> + + ${lu_temp}.resize((*${a}).size()); + (*${result}).resize((*${b}).size()); + ${ipiv} = new int[${a_rows}]; + + ${lu_temp} = (*${a}); + (*${result}) = (*${b}); + ${b_cols_int} = int(${b_cols}); + + ${ltr}gesv_(&${a_rows}, &${b_cols_int}, + ${lu_temp}.data(), &${a_rows}, ${ipiv}, + ${result}->data(), &${b_rows}, &${info}); + + if (${info}!=0) + { + std::cerr << "gesv on " << "${a}" << + " failed with info = " << ${info} << std::endl; + } + + ${lu_temp}.clear(); + delete [] (${ipiv}); + + """) + + +builtin_svd = CallCode(UTIL_MACROS + """ + <% + sigma_size = declare_new("int", "res_size") + a_rows = declare_new("int", "a_rows") + + %> + + ${check_matrix(a, a_cols, a_rows, "svd")} + ${sigma_size} = std::min(int(${a_cols}),int(${a_rows})); + + <% + ltr = get_lapack_letter(a_kind); + + a_temp = declare_new( + kind_to_cxx(a_kind) + , "a_temp") + work = declare_new( + kind_to_cxx(a_kind) + , "work") + info = declare_new("int", "info") + lwork = declare_new("int", "lwork") + lda = declare_new("int", "lda") + ldu = declare_new("int", "ldu") + max1 = declare_new("int", "max1") + max2 = declare_new("int", "max2") + max3 = declare_new("int", "max3") + ldvt = declare_new("int", "ldvt") + a_cols_int = declare_new("int", "a_cols_int") + jobu = declare_new("char", "jobu") + jobvt = declare_new("char", "jobvt") + %> + + ${a_temp}.resize((*${a}).size()); + ${jobu} = 'S'; + ${jobvt} = 'S'; + ${lda} = std::max(1,int(${a_rows})); + ${ldu} = int(${a_rows}); + ${ldvt} = std::min(int(${a_rows}),int(${a_rows})); + + ${a_temp} = (*${a}); + ${max1} = 1; + ${max2} = 3 * std::min(int(${a_rows}), int(${a_cols})) + + std::max(int(${a_rows}), int(${a_cols})); + ${max3} = 5 * std::min(int(${a_rows}), int(${a_cols})); + //${lwork} = *std::max_element({${max1}, ${max2}, ${max3}}); + + if (${max1} >= ${max2} && ${max1} >= ${max3}) + { + ${lwork} = ${max1}; + } + else if (${max2} >= ${max1} && ${max2} >= ${max3}) + { + ${lwork} = ${max2}; + } + else + { + ${lwork} = ${max3}; + } + + (*${sigma}).resize(${sigma_size}); + ${work}.resize(${lwork}); + (*${u}).resize(int(${a_rows} * ${a_rows})); + (*${vt}).resize(int(${a_rows}*${a_cols})); + ${a_cols_int} = int(${a_cols}); + + ${ltr}gesvd_(&${jobu}, &${jobu}, &${a_rows}, &${a_cols_int}, + ${a_temp}.data(), &${lda}, ${sigma}->data(), ${u}->data(), + &${ldu}, ${vt}->data(), + &${ldvt}, ${work}.data(), &${lwork}, &${info}); + + if (${info}!=0) + { + std::cerr << "gesvd on " << ${a} + << " failed with info = " << ${info} << std::endl; + } + + ${a_temp}.clear(); + ${work}.clear(); + + """) + + +builtin_print = CallCode(UTIL_MACROS + """ + std::cout << ${arg} << std::endl; + """) + +# }}} + + +# vim: foldmethod=marker diff --git a/dagrt/codegen/expressions.py b/dagrt/codegen/expressions.py index 249abbc80d6e298585b285c6e764bbc534179b3f..69944a9c73ebb5aa94db3bdce5045b1881fe6e27 100644 --- a/dagrt/codegen/expressions.py +++ b/dagrt/codegen/expressions.py @@ -220,4 +220,108 @@ class PythonExpressionMapper(StringifyMapper): # }}} +# {{{ CXX + + +class CXXExpressionMapper(StringifyMapper): + """Converts expressions to CXX code.""" + + def __init__(self, name_manager): + """name_manager is a map from a variable name (as a string) to its + representation (as a string). + """ + super(CXXExpressionMapper, self).__init__(repr) + self._name_manager = name_manager + + def map_constant(self, expr, enclosing_prec): + if isinstance(expr, (complex, np.complex)): + return "(%s, %s)" % ( + self.rec(expr.real), + self.rec(expr.imag)) + elif isinstance(expr, bool): + if expr: + return ".true." + else: + return ".false." + else: + result = repr(expr) + if expr < 0: + result = "(%s)" % result + return result + + def map_foreign(self, expr, enclosing_prec): + if expr is None: + raise NotImplementedError() + elif isinstance(expr, str): + return repr(expr) + else: + return super(CXXExpressionMapper, self).map_foreign( + expr, enclosing_prec) + + TARGET_PREFIX = "" + + def map_variable(self, expr, enclosing_prec): + if expr.name.startswith(self.TARGET_PREFIX): + return expr.name[len(self.TARGET_PREFIX):] + else: + return self._name_manager[expr.name] + + def map_lookup(self, expr, enclosing_prec): + return self.parenthesize_if_needed( + self.format("%s.%s", + self.rec(expr.aggregate, PREC_CALL), + expr.name), + enclosing_prec, PREC_CALL) + + def map_subscript(self, expr, enclosing_prec): + if isinstance(expr.index, tuple): + index_str = ", ".join( + "%s" % self.rec(i, PREC_NONE) + for i in expr.index) + else: + index_str = "%s" % self.rec(expr.index, PREC_NONE) + + return self.parenthesize_if_needed( + self.format("%s[%s]", + self.rec(expr.aggregate, PREC_CALL), + index_str), + enclosing_prec, PREC_CALL) + + def map_product(self, expr, enclosing_prec, *args, **kwargs): + # This differs from the superclass only by adding spaces + # around the operator, which provide an opportunity for + # line breaking. + return self.parenthesize_if_needed( + self.join_rec(" * ", expr.children, PREC_PRODUCT, *args, **kwargs), + enclosing_prec, PREC_PRODUCT) + + def map_power(self, expr, enclosing_prec, *args, **kwargs): + from pymbolic.mapper.stringifier import PREC_POWER + return self.parenthesize_if_needed( + self.format("pow(%s, %s)", self.rec(expr.base, PREC_POWER), + self.rec(expr.exponent, PREC_POWER)), + enclosing_prec, PREC_POWER) + + def map_logical_not(self, expr, enclosing_prec): + from pymbolic.mapper.stringifier import PREC_UNARY + return self.parenthesize_if_needed( + "!" + self.rec(expr.child, PREC_UNARY), + enclosing_prec, PREC_UNARY) + + def map_logical_or(self, expr, enclosing_prec): + from pymbolic.mapper.stringifier import PREC_LOGICAL_OR + return self.parenthesize_if_needed( + self.join_rec( + " || ", expr.children, PREC_LOGICAL_OR), + enclosing_prec, PREC_LOGICAL_OR) + + def map_logical_and(self, expr, enclosing_prec): + from pymbolic.mapper.stringifier import PREC_LOGICAL_AND + return self.parenthesize_if_needed( + self.join_rec( + " && ", expr.children, PREC_LOGICAL_AND), + enclosing_prec, PREC_LOGICAL_AND) + +# }}} + # vim: foldmethod=marker diff --git a/dagrt/function_registry.py b/dagrt/function_registry.py index f8c8f5c76b0720aafc2e8ff6807ab6909f4b0be6..de8245a66f7eb4f00c34ea4118337f3b3276061e 100644 --- a/dagrt/function_registry.py +++ b/dagrt/function_registry.py @@ -580,6 +580,27 @@ def _make_bfr(): bfr = bfr.register_codegen(Print.identifier, "fortran", f.builtin_print) + import dagrt.codegen.cxx as cxx + + bfr = bfr.register_codegen(Norm2.identifier, "cxx", + cxx.codegen_builtin_norm_2) + bfr = bfr.register_codegen(Len.identifier, "cxx", + cxx.codegen_builtin_len) + bfr = bfr.register_codegen(IsNaN.identifier, "cxx", + cxx.codegen_builtin_isnan) + bfr = bfr.register_codegen(Array_.identifier, "cxx", + cxx.builtin_array) + bfr = bfr.register_codegen(MatMul.identifier, "cxx", + cxx.builtin_matmul) + bfr = bfr.register_codegen(Transpose.identifier, "cxx", + cxx.builtin_transpose) + bfr = bfr.register_codegen(LinearSolve.identifier, "cxx", + cxx.builtin_linear_solve) + bfr = bfr.register_codegen(SVD.identifier, "cxx", + cxx.builtin_svd) + bfr = bfr.register_codegen(Print.identifier, "cxx", + cxx.builtin_print) + return bfr diff --git a/dagrt/utils.py b/dagrt/utils.py index bd8df2e9fbec232b721e49eb4d59d66a4c0b2500..420e6867dd9cfa752b37b36039b0ccd9757a3f20 100644 --- a/dagrt/utils.py +++ b/dagrt/utils.py @@ -216,9 +216,57 @@ def run_fortran(sources, fortran_options=None, fortran_libraries=None): file=sys.stderr) if stderr_data: - raise RuntimeError( - "Fortran code has non-empty stderr:\n" - + stderr_data.decode('ascii')) + raise RuntimeError("Fortran code has non-empty stderr:\n" + + stderr_data.decode('ascii')) + + return p.returncode, stdout_data, stderr_data + +# }}} + + +# {{{ run_cxx + +def run_cxx(sources, cxx_options=None): + if cxx_options is None: + cxx_options = [] + + from os.path import join + + with TemporaryDirectory() as tmpdir: + source_names = [] + for name, contents in sources: + source_names.append(name) + + with open(join(tmpdir, name), "w") as srcf: + srcf.write(contents) + + import os + from subprocess import check_call, Popen, PIPE + check_call( + #[os.environ.get("CXX", "g++"), + # "-g", "-oruntest"] + #+ cxx_options + #+ [source_names[0]], + [os.environ.get("CXX", "g++"), + "-g", "-oruntest"] + + [source_names[0]] + + cxx_options, + cwd=tmpdir) + + p = Popen([join(tmpdir, "runtest")], stdout=PIPE, stderr=PIPE, + close_fds=True) + stdout_data, stderr_data = p.communicate() + + if stdout_data: + print("CXX code said this on stdout: -----------------------------", + file=sys.stderr) + print(stdout_data.decode(), file=sys.stderr) + print("---------------------------------------------------------------", + file=sys.stderr) + + if stderr_data: + raise RuntimeError("CXX code has non-empty stderr:\n" + + stderr_data.decode('ascii')) return p.returncode, stdout_data, stderr_data diff --git a/test/CXX_Testing/FC.h b/test/CXX_Testing/FC.h new file mode 100644 index 0000000000000000000000000000000000000000..467bfb657d634a3ac8d2773d8437cd23e96816ef --- /dev/null +++ b/test/CXX_Testing/FC.h @@ -0,0 +1,16 @@ +#ifndef FC_HEADER_INCLUDED +#define FC_HEADER_INCLUDED + +/* Mangling for Fortran global symbols without underscores. */ +#define FC_GLOBAL(name,NAME) name##_ + +/* Mangling for Fortran global symbols with underscores. */ +#define FC_GLOBAL_(name,NAME) name##_ + +/* Mangling for Fortran module symbols without underscores. */ +#define FC_MODULE(mod_name,name, mod_NAME,NAME) __##mod_name##_MOD_##name + +/* Mangling for Fortran module symbols with underscores. */ +#define FC_MODULE_(mod_name,name, mod_NAME,NAME) __##mod_name##_MOD_##name + +#endif diff --git a/test/CXX_Testing/fe_gen.py b/test/CXX_Testing/fe_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..0860fdc074553bda4ddcf161016247c18fab1d5a --- /dev/null +++ b/test/CXX_Testing/fe_gen.py @@ -0,0 +1,54 @@ +from __future__ import print_function + +import dagrt.codegen.cxx as cxx +from leap.rk import ODE45Method +from time_int_helpers import ( + t_cv_and_aux_type_descr_single_rate, + register_state_update_function_singlerate) + + +def main(): + component_id = 'y' + stepper = ODE45Method(component_id) + + from dagrt.function_registry import ( + base_function_registry, register_ode_rhs) + + freg = register_ode_rhs(base_function_registry, component_id) + freg = freg.register_codegen(""+component_id, "cxx", + cxx.CallCode(""" + + // Simple scalar RHS + + for (int i = 0; i < 101; i++) + { + ${result}[i] = rhs_val; + } + + // ${result} = rhs_val; + + """)) + + freg = register_state_update_function_singlerate(freg) + code = stepper.generate() + + codegen = cxx.CodeGenerator( + 'RK4', + user_type_map={ + component_id: t_cv_and_aux_type_descr_single_rate + }, + function_registry=freg, + call_after_state_update="notify_post_state_update", + extra_arguments="rhs_val", + extra_argument_types="double", + emit_instrumentation=True, + timing_function="clock") + + import sys + with open(sys.argv[1], "w") as outf: + code_str = codegen(code) + print(code_str, file=outf) + + +if __name__ == "__main__": + main() diff --git a/test/CXX_Testing/test_fe.c b/test/CXX_Testing/test_fe.c new file mode 100644 index 0000000000000000000000000000000000000000..c188061dbe5288926e3ad2e631fe4095910bf9f0 --- /dev/null +++ b/test/CXX_Testing/test_fe.c @@ -0,0 +1,44 @@ +// Forward Euler time integration of simple ODE dy/dt = a using new Leap/Dagrt code generation +#include +#include +#include +#include +#include +#include "test_header.H" +#include "FC.h" +using namespace std; + +int main() +{ + // Initialize dagrt state structure to pass to Leap code. + dagrt_state_type dagrt_state; + + // Initialize time, timestep, and state to be integrated. + double t = 0.00; + double dt = 0.01; + double rhs_val = 6.4; + std::shared_ptr t_ptr(new double); + std::shared_ptr dt_ptr(new double); + std::shared_ptr y_ptr(new double[101]); + + *t_ptr = t; + *dt_ptr = dt; + + for(int i = 0; i < 101; i++) + { + y_ptr.get()[i] = (double) i / 100; + } + + // Initialize Leap FE integrator + initialize(rhs_val, dagrt_state, dt_ptr, y_ptr, t_ptr); + for(int i = 0; i < 101; i++) + { + // Run Leap FE integrator + run(rhs_val, dagrt_state); + // cout << dagrt_state.dagrt_t << ' ' << dagrt_state.state_y.get()[0] << endl; + } + + print_profile(dagrt_state); + + return 0; +} diff --git a/test/CXX_Testing/test_header.H b/test/CXX_Testing/test_header.H new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/CXX_Testing/time_int_helpers.py b/test/CXX_Testing/time_int_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..fdba320dc33ed71476e86b710c78dbbf57c0c2c6 --- /dev/null +++ b/test/CXX_Testing/time_int_helpers.py @@ -0,0 +1,23 @@ +import dagrt.codegen.cxx as cxx + +t_cv_and_aux_type_descr_single_rate = cxx.ArrayType( + dimension=(101,), + index_vars='j', + element_type=cxx.BuiltinType('double')) + + +def register_state_update_function_singlerate(freg): + from dagrt.function_registry import register_function + + freg = register_function(freg, "notify_post_state_update", + ("updated_component",)) + freg = freg.register_codegen("notify_post_state_update", "cxx", + cxx.CallCode(""" + // cout << 'after state update' << endl; + + cout << "State Element 1" << + " " << dagrt_state.state_y.get()[0] << endl; + + """)) + + return freg