diff --git a/dagrt/codegen/fortran.py b/dagrt/codegen/fortran.py index f3766e2f62b6268649ee17707e205e9dbd0872b0..b9921e77b222f6d40a2b9b9e26129fdffd0d7bc6 100644 --- a/dagrt/codegen/fortran.py +++ b/dagrt/codegen/fortran.py @@ -1062,11 +1062,9 @@ class CodeGenerator(StructuredCodeGenerator): # }}} - from dagrt.data import SymbolKindFinder - - self.sym_kind_table = SymbolKindFinder(self.function_registry)( - [fd.name for fd in fdescrs], - [get_statements_in_ast(fd.ast) for fd in fdescrs]) + from dagrt.data import infer_kinds + dag = infer_kinds(dag, self.function_registry) + self.sym_kind_table = dag.sym_kind_table from dagrt.codegen.analysis import ( collect_ode_component_names_from_dag, @@ -1084,11 +1082,11 @@ class CodeGenerator(StructuredCodeGenerator): from dagrt.data import Scalar, UserType for comp_id in component_ids: - self.sym_kind_table.set( + self.sym_kind_table = self.sym_kind_table.set( None, ""+comp_id, Scalar(is_real_valued=True)) - self.sym_kind_table.set( + self.sym_kind_table = self.sym_kind_table.set( None, ""+comp_id, Scalar(is_real_valued=True)) - self.sym_kind_table.set( + self.sym_kind_table = self.sym_kind_table.set( None, ""+comp_id, UserType(comp_id)) self.begin_emit(dag) @@ -1634,7 +1632,8 @@ class CodeGenerator(StructuredCodeGenerator): for sym in init_symbols: sym_kind = self.sym_kind_table.global_table[sym] fortran_name = self.name_manager.name_global(sym) - self.sym_kind_table.set(phase_id, ""+fortran_name, sym_kind) + self.sym_kind_table = self.sym_kind_table.set( + phase_id, ""+fortran_name, sym_kind) self.declaration_emitter('type(dagrt_state_type), pointer :: dagrt_state') self.declaration_emitter('') diff --git a/dagrt/data.py b/dagrt/data.py index ebf4219e8eb28ae06757a51da0cd344269b3aa5f..9af9a2805014a105126859956e6509ff14ee9265 100644 --- a/dagrt/data.py +++ b/dagrt/data.py @@ -31,8 +31,9 @@ from dagrt.utils import TODO import six import dagrt.language as lang from dagrt.utils import is_state_variable -from pytools import RecordWithoutPickling +from pytools import ImmutableRecord from pymbolic.mapper import Mapper +from pyrsistent import pmap __doc__ = """ @@ -77,22 +78,9 @@ def _get_arg_dict_from_call_stmt(stmt): # {{{ symbol information -class SymbolKind(RecordWithoutPickling): +class SymbolKind(ImmutableRecord): """Base class for kinds encountered in the :mod:`dagrt` language.""" - - def __eq__(self, other): - return ( - type(self) == type(other) - and self.__getinitargs__() == other.__getinitargs__()) - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash((type(self), self.__getinitargs__())) - - def __getinitargs__(self): - return () + pass class Boolean(SymbolKind): @@ -116,9 +104,6 @@ class Scalar(SymbolKind): def __init__(self, is_real_valued): super(Scalar, self).__init__(is_real_valued=is_real_valued) - def __getinitargs__(self): - return (self.is_real_valued,) - class Array(SymbolKind): """A variable-sized one-dimensional scalar array. @@ -131,9 +116,6 @@ class Array(SymbolKind): def __init__(self, is_real_valued): super(Array, self).__init__(is_real_valued=is_real_valued) - def __getinitargs__(self): - return (self.is_real_valued,) - class UserType(SymbolKind): """Represents user state belonging to a normed vector space. @@ -146,15 +128,14 @@ class UserType(SymbolKind): def __init__(self, identifier): super(UserType, self).__init__(identifier=identifier) - def __getinitargs__(self): - return (self.identifier,) - # }}} -class SymbolKindTable(object): +class SymbolKindTable(ImmutableRecord): """A mapping from symbol names to kinds for a program. + Tables support a read-only mapping interface. + .. attribute:: global_table a mapping from symbol names to :class:`SymbolKind` instances, @@ -164,52 +145,66 @@ class SymbolKindTable(object): a nested mapping ``[phase_name][symbol_name]`` to :class:`SymbolKind` instances + + .. automethod:: get + .. automethod:: set """ - def __init__(self): - self.global_table = { - "": Scalar(is_real_valued=True), - "
": Scalar(is_real_valued=True), - } - self.per_phase_table = {} - self._changed = False + def __init__(self, global_table=None, per_phase_table=None): + if global_table is None: + global_table = pmap({ + "": Scalar(is_real_valued=True), + "
": Scalar(is_real_valued=True), + }) - def reset_change_flag(self): - self._changed = False + if per_phase_table is None: + per_phase_table = pmap() - def is_changed(self): - return self._changed + super(SymbolKindTable, self).__init__( + global_table=global_table, + per_phase_table=per_phase_table) def set(self, phase_name, name, kind): + """Assign a kind to a symbol. + + If *name* is present in the table, the new kind is the result of + unifying *kind* with the kind present in the table. + + :arg phase_name: the name of the associated phase, if any + :arg name: the symbol name + :arg kind: the new kind for *name* + + :returns: a :class:`SymbolKindTable` + """ if is_state_variable(name): tbl = self.global_table else: - tbl = self.per_phase_table.setdefault(phase_name, {}) + tbl = self.per_phase_table.get(phase_name, pmap()) if name in tbl: if tbl[name] != kind: - try: - kind = unify(kind, tbl[name]) - except Exception: - print( - "trying to derive 'kind' for '%s' in " - "'%s': '%s' vs '%s'" - % (name, phase_name, - repr(kind), - repr(tbl[name]))) - else: - if tbl[name] != kind: - self._changed = True - tbl[name] = kind + kind = unify(kind, tbl[name]) + + tbl = tbl.set(name, kind) + if is_state_variable(name): + return self.copy(global_table=tbl) else: - tbl[name] = kind + per_phase_table = self.per_phase_table.set(phase_name, tbl) + return self.copy(per_phase_table=per_phase_table) def get(self, phase_name, name): + """Look up the kind of a symbol. + + :arg phase_name: the name of the associated phase, if any + :arg name: the symbol name + + :returns: a :class:`SymbolKind` + """ if is_state_variable(name): tbl = self.global_table else: - tbl = self.per_phase_table.setdefault(phase_name, {}) + tbl = self.per_phase_table.get(phase_name, pmap()) return tbl[name] @@ -286,11 +281,14 @@ class KindInferenceMapper(Mapper): """ .. attribute:: global_table - The :class:`SymbolKindTable` for the global scope. + A mapping with symbol kind information for the global scope, such as + a :attr:`SymbolKindTable.global_table` .. attribute:: local_table - The :class:`SymbolKindTable` for the phase currently being processed. + A mapping with symbol kind information for symbols local to the phase + being processed, such as a value in + a :attr:`SymbolKindTable.per_phase_table` """ def __init__(self, global_table, local_table, function_registry, check): @@ -446,7 +444,7 @@ class SymbolKindFinder(object): def __init__(self, function_registry): self.function_registry = function_registry - def __call__(self, names, phases): + def __call__(self, names, phases, initial_table=None): """Infer the kinds of all the symbols in a program. :arg names: a list of phase names @@ -463,7 +461,7 @@ class SymbolKindFinder(object): expanded_phases.append(list(phase)) phases = expanded_phases - result = SymbolKindTable() + result = initial_table def make_kim(phase_name, check): return KindInferenceMapper( @@ -479,8 +477,7 @@ class SymbolKindFinder(object): stmt_queue_push_buffer = [] made_progress = False - - result.reset_change_flag() + old_result = result while stmt_queue or stmt_queue_push_buffer: if not stmt_queue: @@ -529,7 +526,7 @@ class SymbolKindFinder(object): kim = make_kim(phase_name, check=False) for ident, _, _ in stmt.loops: - result.set(phase_name, ident, kind=Integer()) + result = result.set(phase_name, ident, kind=Integer()) if stmt.assignee_subscript: continue @@ -540,7 +537,7 @@ class SymbolKindFinder(object): stmt_queue_push_buffer.append((phase_name, stmt)) else: made_progress = True - result.set(phase_name, stmt.assignee, kind=kind) + result = result.set(phase_name, stmt.assignee, kind=kind) elif isinstance(stmt, lang.AssignFunctionCall): kim = make_kim(phase_name, check=False) @@ -554,7 +551,7 @@ class SymbolKindFinder(object): else: made_progress = True for assignee, kind in zip(stmt.assignees, kinds): - result.set(phase_name, assignee, kind=kind) + result = result.set(phase_name, assignee, kind=kind) elif isinstance(stmt, lang.AssignmentBase): raise TODO() @@ -563,7 +560,7 @@ class SymbolKindFinder(object): # We only care about assignments. pass - if not result.is_changed(): + if result == old_result: break # {{{ check consistency of obtained kinds @@ -610,7 +607,7 @@ def infer_kinds(dag, function_registry=None): :arg dag: a :class:`dagrt.language.DAGCode` :arg function_registry: if not *None*, the function registry to use - :returns: a :class:`SymbolKindTable` + :returns: an updated :class:`DAGCode` """ if function_registry is None: from dagrt.function_registry import base_function_registry @@ -620,7 +617,9 @@ def infer_kinds(dag, function_registry=None): names = list(dag.phases) phases = [phase.statements for phase in dag.phases.values()] - return kind_finder(names, phases) + sym_kind_table = kind_finder(names, phases, dag.sym_kind_table) + + return dag.copy(sym_kind_table=sym_kind_table) # }}} diff --git a/dagrt/language.py b/dagrt/language.py index 4603d7bd1d664ebfebeb015d60b7d89dc92d1c88..1f47564a4dfc57a9d8f9352229558e094006f2b9 100644 --- a/dagrt/language.py +++ b/dagrt/language.py @@ -649,22 +649,30 @@ class DAGCode(RecordWithoutPickling): .. attribute:: initial_phase the name of the starting phase + + .. attribute:: sym_kind_table + + a :class:`dagrt.data.SymbolKindTable` with symbol kind information """ @classmethod - def from_phases_list(cls, phases, initial_phase): + def from_phases_list(cls, phases, initial_phase, sym_kind_table=None): name_to_phase = dict() for phase in phases: if phase.name in name_to_phase: raise ValueError("duplicate phase name '%s'" % phase.name) name_to_phase[phase.name] = phase - return cls(name_to_phase, initial_phase) + return cls(name_to_phase, initial_phase, sym_kind_table) - def __init__(self, phases, initial_phase): + def __init__(self, phases, initial_phase, sym_kind_table=None): + if sym_kind_table is None: + from dagrt.data import SymbolKindTable + sym_kind_table = SymbolKindTable() assert not isinstance(phases, list) RecordWithoutPickling.__init__(self, phases=phases, - initial_phase=initial_phase) + initial_phase=initial_phase, + sym_kind_table=sym_kind_table) # {{{ identifier wrangling @@ -692,6 +700,7 @@ class DAGCode(RecordWithoutPickling): def __str__(self): lines = [] + lines.append("===== Phases =====") for phase_name, phase in sorted(six.iteritems(self.phases)): phase_title = "PHASE \"%s\"" % phase_name if phase_name == self.initial_phase: @@ -706,8 +715,10 @@ class DAGCode(RecordWithoutPickling): lines.append(" -> (next phase) \"%s\"" % phase.next_phase) lines.append("") - return "\n".join(lines) + lines.append("===== Known Symbol Kinds =====") + lines.append(str(self.sym_kind_table)) + return "\n".join(lines) # }}} diff --git a/setup.py b/setup.py index 706c6ed13389d9f2e1997a4035a542bffaf955d8..2000b06a648797dfaf4773d6af8a1b9f5480d684 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ def main(): "pytest>=2.3", "mako", "six", + "pyrsistent", ], ) diff --git a/test/test_kind_inference.py b/test/test_kind_inference.py new file mode 100755 index 0000000000000000000000000000000000000000..4b753e49b573ea13629f8d70ffbc32c16d0a4613 --- /dev/null +++ b/test/test_kind_inference.py @@ -0,0 +1,113 @@ +#! /usr/bin/env python +from __future__ import division, with_statement + +import sys + +from dagrt.expression import parse +from dagrt.language import CodeBuilder +from dagrt.data import infer_kinds, Scalar, Boolean, Integer, Array, UserType +from dagrt.function_registry import base_function_registry, register_ode_rhs + +from utils import create_DAGCode_with_steady_phase + + +__copyright__ = "Copyright (C) 2020 Matt Wala" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +def test_kind_inference_scalar(): + with CodeBuilder("main") as cb: + cb("x", 1.) + cb("y", 1.j) + + code = create_DAGCode_with_steady_phase(cb.statements) + code = infer_kinds(code) + + kinds = code.sym_kind_table.per_phase_table["main"] + + assert kinds["x"] == Scalar(is_real_valued=True) + assert kinds["y"] == Scalar(is_real_valued=False) + + +def test_kind_inference_boolean(): + with CodeBuilder("main") as cb: + cb("flag1", parse("1 > 2")) + cb("flag2", parse("not flag1")) + cb("flag3", parse("flag1 and flag2")) + cb("flag4", parse("flag1 or flag2")) + + code = create_DAGCode_with_steady_phase(cb.statements) + code = infer_kinds(code) + + kinds = code.sym_kind_table.per_phase_table["main"] + + assert kinds["flag1"] == Boolean() + assert kinds["flag2"] == Boolean() + assert kinds["flag3"] == Boolean() + assert kinds["flag4"] == Boolean() + + +def test_kind_inference_integer(): + with CodeBuilder("main") as cb: + cb("arr", "array(10)") + cb("arr[i]", 1., [("i", 0, 10)]) + + code = create_DAGCode_with_steady_phase(cb.statements) + code = infer_kinds(code) + + kinds = code.sym_kind_table.per_phase_table["main"] + assert kinds["i"] == Integer() + + +def test_kind_inference_array(): + with CodeBuilder("main") as cb: + cb("arr", "array(10)") + cb("arr", 1) + cb("arr_complex", "array(10)") + cb("arr_complex", 1j) + + code = create_DAGCode_with_steady_phase(cb.statements) + code = infer_kinds(code) + + kinds = code.sym_kind_table.per_phase_table["main"] + assert kinds["arr"] == Array(is_real_valued=True) + assert kinds["arr_complex"] == Array(is_real_valued=False) + + +def test_kind_inference_user_type(): + with CodeBuilder("main") as cb: + cb("y", parse("y +
* f(, y)")) + + code = create_DAGCode_with_steady_phase(cb.statements) + freg = register_ode_rhs(base_function_registry, "y", identifier="f") + code = infer_kinds(code, freg) + + kinds = code.sym_kind_table.global_table + assert kinds["y"] == UserType("y") + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__])