diff --git a/dagrt/codegen/fortran.py b/dagrt/codegen/fortran.py index 496891ffbcce3f53160c83ba3e185fb6c1fa44ef..1b2f8e30cb63ad2e502a6a9c5a0f27fa9c4e16d6 100644 --- a/dagrt/codegen/fortran.py +++ b/dagrt/codegen/fortran.py @@ -31,7 +31,7 @@ import six from dagrt.codegen.expressions import FortranExpressionMapper from dagrt.codegen.codegen_base import StructuredCodeGenerator from dagrt.utils import is_state_variable -from dagrt.codegen.data import UserType +from dagrt.data import UserType from pytools.py_codegen import ( # It's the same code. So sue me. PythonCodeGenerator as FortranEmitterBase) @@ -236,7 +236,7 @@ class CallCode(object): code_generator.declaration_emitter(decl_without_name + " :: " + new_name) return new_name - import dagrt.codegen.data as kinds + import dagrt.data as kinds template_names = dict( real_scalar_kind=code_generator.real_scalar_kind, @@ -1062,11 +1062,11 @@ class CodeGenerator(StructuredCodeGenerator): # }}} - from dagrt.codegen.data import SymbolKindFinder + from dagrt.data import SymbolKindFinder self.sym_kind_table = SymbolKindFinder(self.function_registry)( [fd.name for fd in fdescrs], - [fd.ast for fd in fdescrs]) + [get_statements_in_ast(fd.ast) for fd in fdescrs]) from dagrt.codegen.analysis import ( collect_ode_component_names_from_dag, @@ -1082,7 +1082,7 @@ class CodeGenerator(StructuredCodeGenerator): raise RuntimeError("User type missing from user type map: %r" % (component_ids - set(self.user_type_map))) - from dagrt.codegen.data import Scalar, UserType + from dagrt.data import Scalar, UserType for comp_id in component_ids: self.sym_kind_table.set( None, ""+comp_id, Scalar(is_real_valued=True)) @@ -1151,7 +1151,7 @@ class CodeGenerator(StructuredCodeGenerator): # {{{ emit variable deinit - sym_table = self.sym_kind_table.per_function_table.get( + sym_table = self.sym_kind_table.per_phase_table.get( self.current_function, {}) for identifier, sym_kind in sorted(six.iteritems(sym_table)): @@ -1305,7 +1305,7 @@ class CodeGenerator(StructuredCodeGenerator): # {{{ memory management routines - from dagrt.codegen.data import collect_user_types + from dagrt.data import collect_user_types user_types = collect_user_types(self.sym_kind_table) # {{{ allocation checks @@ -1430,11 +1430,11 @@ class CodeGenerator(StructuredCodeGenerator): if emit is None: emit = self.emit - from dagrt.codegen.data import Boolean, Scalar, Array, Integer + from dagrt.data import Boolean, Scalar, Array, Integer type_specifiers = other_specifiers - from dagrt.codegen.data import UserType + from dagrt.data import UserType if isinstance(sym_kind, Boolean): type_name = 'logical' @@ -1482,7 +1482,7 @@ class CodeGenerator(StructuredCodeGenerator): id=fortran_name)) def emit_variable_init(self, name, sym_kind): - from dagrt.codegen.data import UserType + from dagrt.data import UserType if isinstance(sym_kind, UserType): ftype = self.get_fortran_type_for_user_type(sym_kind.identifier) InitializationEmitter(self)(ftype, self.name_manager[name], {}) @@ -1491,7 +1491,7 @@ class CodeGenerator(StructuredCodeGenerator): fortran_name = self.name_manager[name] refcnt_name = self.name_manager.name_refcount(name) - from dagrt.codegen.data import UserType + from dagrt.data import UserType if not isinstance(sym_kind, UserType): return @@ -1508,7 +1508,7 @@ class CodeGenerator(StructuredCodeGenerator): def emit_refcounted_allocation(self, sym, sym_kind): fortran_name = self.name_manager[sym] - from dagrt.codegen.data import UserType + from dagrt.data import UserType if not isinstance(sym_kind, UserType): return @@ -1583,7 +1583,7 @@ class CodeGenerator(StructuredCodeGenerator): subscript_str=subscript_str, expr=str(expr)[:50])) - from dagrt.codegen.data import UserType + from dagrt.data import UserType if not isinstance(sym_kind, UserType): self.emit( "{name}{subscript_str} = {expr}" @@ -1677,7 +1677,7 @@ class CodeGenerator(StructuredCodeGenerator): # All our scalars are floating-point numbers for now, # so initializing them all to NaN is fine. - from dagrt.codegen.data import Scalar + from dagrt.data import Scalar if sym.startswith("": Scalar(is_real_valued=True), "
": Scalar(is_real_valued=True), } - self.per_function_table = {} + self.per_phase_table = {} self._changed = False def reset_change_flag(self): @@ -140,11 +180,11 @@ class SymbolKindTable(object): def is_changed(self): return self._changed - def set(self, func_name, name, kind): + def set(self, phase_name, name, kind): if is_state_variable(name): tbl = self.global_table else: - tbl = self.per_function_table.setdefault(func_name, {}) + tbl = self.per_phase_table.setdefault(phase_name, {}) if name in tbl: if tbl[name] != kind: @@ -154,7 +194,7 @@ class SymbolKindTable(object): print( "trying to derive 'kind' for '%s' in " "'%s': '%s' vs '%s'" - % (name, func_name, + % (name, phase_name, repr(kind), repr(tbl[name]))) else: @@ -165,11 +205,11 @@ class SymbolKindTable(object): else: tbl[name] = kind - def get(self, func_name, name): + def get(self, phase_name, name): if is_state_variable(name): tbl = self.global_table else: - tbl = self.per_function_table.setdefault(func_name, {}) + tbl = self.per_phase_table.setdefault(phase_name, {}) return tbl[name] @@ -181,8 +221,8 @@ class SymbolKindTable(object): return "\n".join( ["global:\n%s" % format_table(self.global_table)] + [ - "func '%s':\n%s" % (func_name, format_table(tbl)) - for func_name, tbl in self.per_function_table.items()]) + "phase '%s':\n%s" % (phase_name, format_table(tbl)) + for phase_name, tbl in self.per_phase_table.items()]) # {{{ kind inference mapper @@ -250,8 +290,7 @@ class KindInferenceMapper(Mapper): .. attribute:: local_table - The :class:`SymbolKindTable` for the :class:`dagrt.ir.Function` - currently being processed. + The :class:`SymbolKindTable` for the phase currently being processed. """ def __init__(self, global_table, local_table, function_registry, check): @@ -401,30 +440,42 @@ class KindInferenceMapper(Mapper): # {{{ symbol kind finder class SymbolKindFinder(object): + """ + .. automethod:: __call__ + """ def __init__(self, function_registry): self.function_registry = function_registry - def __call__(self, names, functions): - """Return a :class:`SymbolKindTable`. + def __call__(self, names, phases): + """Infer the kinds of all the symbols in a program. + + :arg names: a list of phase names + :arg phases: a list of iterables, each yielding the statements in a + phase + + :returns: a :class:`SymbolKindTable` + + :raises UnableToInferKind: kind inference could not complete sucessfully """ - result = SymbolKindTable() + expanded_phases = [] + for phase in phases: + expanded_phases.append(list(phase)) + phases = expanded_phases - from dagrt.codegen.dag_ast import get_statements_in_ast + result = SymbolKindTable() - def make_kim(func_name, check): + def make_kim(phase_name, check): return KindInferenceMapper( result.global_table, - result.per_function_table.get(func_name, {}), + result.per_phase_table.get(phase_name, {}), self.function_registry, check=False) while True: stmt_queue = [] - for name, func in zip(names, functions): - stmt_queue.extend( - (name, stmt) - for stmt in get_statements_in_ast(func)) + for name, phase in zip(names, phases): + stmt_queue.extend((name, stmt) for stmt in phase) stmt_queue_push_buffer = [] made_progress = False @@ -437,10 +488,10 @@ class SymbolKindFinder(object): if not made_progress: print("Left-over statements in kind inference:") - for func_name, stmt in stmt_queue_push_buffer: - print("[%s] %s" % (func_name, stmt)) + for phase_name, stmt in stmt_queue_push_buffer: + print("[%s] %s" % (phase_name, stmt)) - kim = make_kim(func_name, check=False) + kim = make_kim(phase_name, check=False) try: if isinstance(stmt, lang.Assign): @@ -472,13 +523,13 @@ class SymbolKindFinder(object): stmt_queue_push_buffer = [] made_progress = False - func_name, stmt = stmt_queue.pop() + phase_name, stmt = stmt_queue.pop() if isinstance(stmt, lang.Assign): - kim = make_kim(func_name, check=False) + kim = make_kim(phase_name, check=False) for ident, _, _ in stmt.loops: - result.set(func_name, ident, kind=Integer()) + result.set(phase_name, ident, kind=Integer()) if stmt.assignee_subscript: continue @@ -486,24 +537,24 @@ class SymbolKindFinder(object): try: kind = kim(stmt.expression) except UnableToInferKind: - stmt_queue_push_buffer.append((func_name, stmt)) + stmt_queue_push_buffer.append((phase_name, stmt)) else: made_progress = True - result.set(func_name, stmt.assignee, kind=kind) + result.set(phase_name, stmt.assignee, kind=kind) elif isinstance(stmt, lang.AssignFunctionCall): - kim = make_kim(func_name, check=False) + kim = make_kim(phase_name, check=False) try: kinds = kim.map_generic_call(stmt.function_id, _get_arg_dict_from_call_stmt(stmt), single_return_only=False) except UnableToInferKind: - stmt_queue_push_buffer.append((func_name, stmt)) + stmt_queue_push_buffer.append((phase_name, stmt)) else: made_progress = True for assignee, kind in zip(stmt.assignees, kinds): - result.set(func_name, assignee, kind=kind) + result.set(phase_name, assignee, kind=kind) elif isinstance(stmt, lang.AssignmentBase): raise TODO() @@ -517,10 +568,10 @@ class SymbolKindFinder(object): # {{{ check consistency of obtained kinds - for func_name, func in zip(names, functions): - kim = make_kim(func_name, check=True) + for phase_name, phase in zip(names, phases): + kim = make_kim(phase_name, check=True) - for stmt in get_statements_in_ast(func): + for stmt in phase: if isinstance(stmt, lang.Assign): kim(stmt.expression) @@ -551,16 +602,44 @@ class SymbolKindFinder(object): # }}} +# {{{ infer kinds of a DAGCode object + +def infer_kinds(dag, function_registry=None): + """Run kind inference on a :class:`dagrt.language.DAGCode`. + + :arg dag: a :class:`dagrt.language.DAGCode` + :arg function_registry: if not *None*, the function registry to use + + :returns: a :class:`SymbolKindTable` + """ + if function_registry is None: + from dagrt.function_registry import base_function_registry + function_registry = base_function_registry + + kind_finder = SymbolKindFinder(function_registry) + names = list(dag.phases) + phases = [phase.statements for phase in dag.phases.values()] + + return kind_finder(names, phases) + +# }}} + + # {{{ collect user types def collect_user_types(skt): + """Collect all of the of :class:`UserType` identifiers in a table. + + :arg skt: a :class:`SymbolKindTable` + :returns: a set of strings + """ result = set() for kind in six.itervalues(skt.global_table): if isinstance(kind, UserType): result.add(kind.identifier) - for tbl in six.itervalues(skt.per_function_table): + for tbl in six.itervalues(skt.per_phase_table): for kind in six.itervalues(tbl): if isinstance(kind, UserType): result.add(kind.identifier) diff --git a/dagrt/function_registry.py b/dagrt/function_registry.py index e084e7c05386ab18cc6b7326fd8c874cb36e0256..f8c8f5c76b0720aafc2e8ff6807ab6909f4b0be6 100644 --- a/dagrt/function_registry.py +++ b/dagrt/function_registry.py @@ -27,7 +27,7 @@ THE SOFTWARE. from pytools import RecordWithoutPickling -from dagrt.codegen.data import ( +from dagrt.data import ( UserType, Integer, Boolean, Scalar, Array, UnableToInferKind) NoneType = type(None) @@ -119,7 +119,7 @@ class Function(RecordWithoutPickling): **kwargs) def get_result_kinds(self, arg_kinds, check): - """Return a tuple of the :class:`dagrt.codegen.data.SymbolKind` + """Return a tuple of the :class:`dagrt.data.SymbolKind` instances for the values this function returns if arguments of the kinds *arg_kinds* are supplied. @@ -128,7 +128,7 @@ class Function(RecordWithoutPickling): :arg arg_kinds: a dictionary mapping numbers (for positional arguments) or identifiers (for keyword arguments) to - :class:`dagrt.codegen.data.SymbolKind` instances indicating the + :class:`dagrt.data.SymbolKind` instances indicating the types of the arguments being passed to the function. Some elements of *arg_kinds* may be None if their kinds have yet not been determined. @@ -696,7 +696,7 @@ def register_function( :arg default_dict: a dictionary mapping argument names to default values :arg result_names: a list of strings, the names of the output(s) - :arg result_kinds: a list of :class:`dagrt.codegen.data.SymbolKinds`, + :arg result_kinds: a list of :class:`dagrt.data.SymbolKinds`, the kinds of the output(s) :returns: a new :class:`FunctionRegistry` diff --git a/doc/reference.rst b/doc/reference.rst index 61a432c3202147afd0a64c5342efaba8c01d3de4..fe51d643dc6966b67f0b3a31278e1c26140057a1 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -30,6 +30,11 @@ Function registry .. automodule:: dagrt.function_registry +Data +~~~~~~~~~~~~~~~~~ + +.. automodule:: dagrt.data + Transformations ~~~~~~~~~~~~~~~