From 2baf5a665cbd5cde95aba1a3e1057959ca2f75ef Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Wed, 5 Feb 2020 23:20:47 -0600 Subject: [PATCH 1/9] Add a SymbolKindTable to DAGCode, make infer_kinds() a transformation This adds a new attribute called *sym_kind_table* to DAGCode which is a SymbolKindTable. The function *infer_kinds* is changed to be a transformation that returns a modified DAGCode object with all symbol kinds inferred. Since DAGCode is immutable, SymbolKindTable is changed to be immutable as well. The interface is changed to make SymbolKindTable free of side effects. To maintain efficiency, the underlying dicts are replaced with persistent dictionaries provided by pyrsistent. This change also makes it an error when unification fails in SymbolKindTable.set() Closes #35 by providing an interface to supply kind information when kind inference won't work --- dagrt/codegen/fortran.py | 17 +++--- dagrt/data.py | 127 +++++++++++++++++++-------------------- dagrt/language.py | 17 ++++-- setup.py | 1 + 4 files changed, 82 insertions(+), 80 deletions(-) diff --git a/dagrt/codegen/fortran.py b/dagrt/codegen/fortran.py index f3766e2..b9921e7 100644 --- a/dagrt/codegen/fortran.py +++ b/dagrt/codegen/fortran.py @@ -1062,11 +1062,9 @@ class CodeGenerator(StructuredCodeGenerator): # }}} - from dagrt.data import SymbolKindFinder - - self.sym_kind_table = SymbolKindFinder(self.function_registry)( - [fd.name for fd in fdescrs], - [get_statements_in_ast(fd.ast) for fd in fdescrs]) + from dagrt.data import infer_kinds + dag = infer_kinds(dag, self.function_registry) + self.sym_kind_table = dag.sym_kind_table from dagrt.codegen.analysis import ( collect_ode_component_names_from_dag, @@ -1084,11 +1082,11 @@ class CodeGenerator(StructuredCodeGenerator): from dagrt.data import Scalar, UserType for comp_id in component_ids: - self.sym_kind_table.set( + self.sym_kind_table = self.sym_kind_table.set( None, ""+comp_id, Scalar(is_real_valued=True)) - self.sym_kind_table.set( + self.sym_kind_table = self.sym_kind_table.set( None, ""+comp_id, Scalar(is_real_valued=True)) - self.sym_kind_table.set( + self.sym_kind_table = self.sym_kind_table.set( None, ""+comp_id, UserType(comp_id)) self.begin_emit(dag) @@ -1634,7 +1632,8 @@ class CodeGenerator(StructuredCodeGenerator): for sym in init_symbols: sym_kind = self.sym_kind_table.global_table[sym] fortran_name = self.name_manager.name_global(sym) - self.sym_kind_table.set(phase_id, ""+fortran_name, sym_kind) + self.sym_kind_table = self.sym_kind_table.set( + phase_id, ""+fortran_name, sym_kind) self.declaration_emitter('type(dagrt_state_type), pointer :: dagrt_state') self.declaration_emitter('') diff --git a/dagrt/data.py b/dagrt/data.py index ebf4219..fe90024 100644 --- a/dagrt/data.py +++ b/dagrt/data.py @@ -31,8 +31,9 @@ from dagrt.utils import TODO import six import dagrt.language as lang from dagrt.utils import is_state_variable -from pytools import RecordWithoutPickling +from pytools import ImmutableRecord from pymbolic.mapper import Mapper +from pyrsistent import pmap __doc__ = """ @@ -77,22 +78,9 @@ def _get_arg_dict_from_call_stmt(stmt): # {{{ symbol information -class SymbolKind(RecordWithoutPickling): +class SymbolKind(ImmutableRecord): """Base class for kinds encountered in the :mod:`dagrt` language.""" - - def __eq__(self, other): - return ( - type(self) == type(other) - and self.__getinitargs__() == other.__getinitargs__()) - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash((type(self), self.__getinitargs__())) - - def __getinitargs__(self): - return () + pass class Boolean(SymbolKind): @@ -116,9 +104,6 @@ class Scalar(SymbolKind): def __init__(self, is_real_valued): super(Scalar, self).__init__(is_real_valued=is_real_valued) - def __getinitargs__(self): - return (self.is_real_valued,) - class Array(SymbolKind): """A variable-sized one-dimensional scalar array. @@ -131,9 +116,6 @@ class Array(SymbolKind): def __init__(self, is_real_valued): super(Array, self).__init__(is_real_valued=is_real_valued) - def __getinitargs__(self): - return (self.is_real_valued,) - class UserType(SymbolKind): """Represents user state belonging to a normed vector space. @@ -146,15 +128,14 @@ class UserType(SymbolKind): def __init__(self, identifier): super(UserType, self).__init__(identifier=identifier) - def __getinitargs__(self): - return (self.identifier,) - # }}} -class SymbolKindTable(object): +class SymbolKindTable(ImmutableRecord): """A mapping from symbol names to kinds for a program. + Tables support a read-only mapping interface. + .. attribute:: global_table a mapping from symbol names to :class:`SymbolKind` instances, @@ -166,50 +147,61 @@ class SymbolKindTable(object): instances """ - def __init__(self): - self.global_table = { - "": Scalar(is_real_valued=True), - "
": Scalar(is_real_valued=True), - } - self.per_phase_table = {} - self._changed = False + def __init__(self, global_table=None, per_phase_table=None): + if global_table is None: + global_table = pmap({ + "": Scalar(is_real_valued=True), + "
": Scalar(is_real_valued=True), + }) - def reset_change_flag(self): - self._changed = False + if per_phase_table is None: + per_phase_table = pmap() - def is_changed(self): - return self._changed + super(SymbolKindTable, self).__init__( + global_table=global_table, + per_phase_table=per_phase_table) def set(self, phase_name, name, kind): + """Assign a kind to a symbol. + + If *name* is present in the table, the new kind is the result of + unifying *kind* with the kind present in the table. + + :arg phase_name: the name of the associated phase, if any + :arg name: the symbol name + :arg kind: the new kind for *name* + + :returns: a :class:`SymbolKindTable` + """ if is_state_variable(name): tbl = self.global_table else: - tbl = self.per_phase_table.setdefault(phase_name, {}) + tbl = self.per_phase_table.get(phase_name, pmap()) if name in tbl: if tbl[name] != kind: - try: - kind = unify(kind, tbl[name]) - except Exception: - print( - "trying to derive 'kind' for '%s' in " - "'%s': '%s' vs '%s'" - % (name, phase_name, - repr(kind), - repr(tbl[name]))) - else: - if tbl[name] != kind: - self._changed = True - tbl[name] = kind + kind = unify(kind, tbl[name]) + + tbl = tbl.set(name, kind) + if is_state_variable(name): + return self.copy(global_table=tbl) else: - tbl[name] = kind + per_phase_table = self.per_phase_table.set(phase_name, tbl) + return self.copy(per_phase_table=per_phase_table) def get(self, phase_name, name): + """Look up the kind of a symbol. + + :arg phase_name: the name of the associated phase, if any + :arg name: the symbol name + + :returns: a :class:`SymbolKind` + """ if is_state_variable(name): tbl = self.global_table else: - tbl = self.per_phase_table.setdefault(phase_name, {}) + tbl = self.per_phase_table.get(phase_name, pmap()) return tbl[name] @@ -286,11 +278,14 @@ class KindInferenceMapper(Mapper): """ .. attribute:: global_table - The :class:`SymbolKindTable` for the global scope. + A mapping with symbol kind information for the global scope, such as + a :attr:`SymbolKindTable.global_table` .. attribute:: local_table - The :class:`SymbolKindTable` for the phase currently being processed. + A mapping with symbol kind information for symbols local to the phase + being processed, such as a value in + a :attr:`SymbolKindTable.per_phase_table` """ def __init__(self, global_table, local_table, function_registry, check): @@ -446,7 +441,7 @@ class SymbolKindFinder(object): def __init__(self, function_registry): self.function_registry = function_registry - def __call__(self, names, phases): + def __call__(self, names, phases, initial_table=None): """Infer the kinds of all the symbols in a program. :arg names: a list of phase names @@ -463,7 +458,7 @@ class SymbolKindFinder(object): expanded_phases.append(list(phase)) phases = expanded_phases - result = SymbolKindTable() + result = initial_table def make_kim(phase_name, check): return KindInferenceMapper( @@ -478,9 +473,7 @@ class SymbolKindFinder(object): stmt_queue.extend((name, stmt) for stmt in phase) stmt_queue_push_buffer = [] - made_progress = False - - result.reset_change_flag() + old_result = result while stmt_queue or stmt_queue_push_buffer: if not stmt_queue: @@ -529,7 +522,7 @@ class SymbolKindFinder(object): kim = make_kim(phase_name, check=False) for ident, _, _ in stmt.loops: - result.set(phase_name, ident, kind=Integer()) + result = result.set(phase_name, ident, kind=Integer()) if stmt.assignee_subscript: continue @@ -540,7 +533,7 @@ class SymbolKindFinder(object): stmt_queue_push_buffer.append((phase_name, stmt)) else: made_progress = True - result.set(phase_name, stmt.assignee, kind=kind) + result = result.set(phase_name, stmt.assignee, kind=kind) elif isinstance(stmt, lang.AssignFunctionCall): kim = make_kim(phase_name, check=False) @@ -554,7 +547,7 @@ class SymbolKindFinder(object): else: made_progress = True for assignee, kind in zip(stmt.assignees, kinds): - result.set(phase_name, assignee, kind=kind) + result = result.set(phase_name, assignee, kind=kind) elif isinstance(stmt, lang.AssignmentBase): raise TODO() @@ -563,7 +556,7 @@ class SymbolKindFinder(object): # We only care about assignments. pass - if not result.is_changed(): + if result == old_result: break # {{{ check consistency of obtained kinds @@ -610,7 +603,7 @@ def infer_kinds(dag, function_registry=None): :arg dag: a :class:`dagrt.language.DAGCode` :arg function_registry: if not *None*, the function registry to use - :returns: a :class:`SymbolKindTable` + :returns: an updated :class:`DAGCode` """ if function_registry is None: from dagrt.function_registry import base_function_registry @@ -620,7 +613,9 @@ def infer_kinds(dag, function_registry=None): names = list(dag.phases) phases = [phase.statements for phase in dag.phases.values()] - return kind_finder(names, phases) + sym_kind_table = kind_finder(names, phases, dag.sym_kind_table) + + return dag.copy(sym_kind_table=sym_kind_table) # }}} diff --git a/dagrt/language.py b/dagrt/language.py index 4603d7b..6ec8b85 100644 --- a/dagrt/language.py +++ b/dagrt/language.py @@ -649,22 +649,30 @@ class DAGCode(RecordWithoutPickling): .. attribute:: initial_phase the name of the starting phase + + .. attribute:: sym_kind_table + + a :class:`dagrt.data.SymbolKindTable` with symbol kind information """ @classmethod - def from_phases_list(cls, phases, initial_phase): + def from_phases_list(cls, phases, initial_phase, symbol_kinds=None): name_to_phase = dict() for phase in phases: if phase.name in name_to_phase: raise ValueError("duplicate phase name '%s'" % phase.name) name_to_phase[phase.name] = phase - return cls(name_to_phase, initial_phase) + return cls(name_to_phase, initial_phase, symbol_kinds) - def __init__(self, phases, initial_phase): + def __init__(self, phases, initial_phase, sym_kind_table=None): + if sym_kind_table is None: + from dagrt.data import SymbolKindTable + sym_kind_table = SymbolKindTable() assert not isinstance(phases, list) RecordWithoutPickling.__init__(self, phases=phases, - initial_phase=initial_phase) + initial_phase=initial_phase, + sym_kind_table=sym_kind_table) # {{{ identifier wrangling @@ -708,7 +716,6 @@ class DAGCode(RecordWithoutPickling): return "\n".join(lines) - # }}} diff --git a/setup.py b/setup.py index 706c6ed..2000b06 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ def main(): "pytest>=2.3", "mako", "six", + "pyrsistent", ], ) -- GitLab From 7e97e49d308a1ec2b46e683fc40c0660032c4a18 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Wed, 5 Feb 2020 23:40:58 -0600 Subject: [PATCH 2/9] Fix a flake8 error --- dagrt/data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dagrt/data.py b/dagrt/data.py index fe90024..436ae6a 100644 --- a/dagrt/data.py +++ b/dagrt/data.py @@ -473,6 +473,7 @@ class SymbolKindFinder(object): stmt_queue.extend((name, stmt) for stmt in phase) stmt_queue_push_buffer = [] + made_progress = False old_result = result while stmt_queue or stmt_queue_push_buffer: -- GitLab From fe48ad1ca9c128569618ecfc5017273b7f1b20b1 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Thu, 6 Feb 2020 02:28:34 -0600 Subject: [PATCH 3/9] DAGCode: Also print symbol kind table --- dagrt/language.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dagrt/language.py b/dagrt/language.py index 6ec8b85..fdf290e 100644 --- a/dagrt/language.py +++ b/dagrt/language.py @@ -700,6 +700,7 @@ class DAGCode(RecordWithoutPickling): def __str__(self): lines = [] + lines.append("===== Phases =====") for phase_name, phase in sorted(six.iteritems(self.phases)): phase_title = "PHASE \"%s\"" % phase_name if phase_name == self.initial_phase: @@ -714,6 +715,9 @@ class DAGCode(RecordWithoutPickling): lines.append(" -> (next phase) \"%s\"" % phase.next_phase) lines.append("") + lines.append("===== Known Symbol Kinds =====") + lines.append(str(self.sym_kind_table)) + return "\n".join(lines) # }}} -- GitLab From 028448d779ad1fe87676eb3caf1c6669b79838ed Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Thu, 6 Feb 2020 02:33:56 -0600 Subject: [PATCH 4/9] Fix argument name in constructor --- dagrt/language.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dagrt/language.py b/dagrt/language.py index fdf290e..1f47564 100644 --- a/dagrt/language.py +++ b/dagrt/language.py @@ -656,13 +656,13 @@ class DAGCode(RecordWithoutPickling): """ @classmethod - def from_phases_list(cls, phases, initial_phase, symbol_kinds=None): + def from_phases_list(cls, phases, initial_phase, sym_kind_table=None): name_to_phase = dict() for phase in phases: if phase.name in name_to_phase: raise ValueError("duplicate phase name '%s'" % phase.name) name_to_phase[phase.name] = phase - return cls(name_to_phase, initial_phase, symbol_kinds) + return cls(name_to_phase, initial_phase, sym_kind_table) def __init__(self, phases, initial_phase, sym_kind_table=None): if sym_kind_table is None: -- GitLab From 1cf2d30b9e5406a142835dd3db853bc435ece0a5 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Thu, 6 Feb 2020 02:36:32 -0600 Subject: [PATCH 5/9] make docs for SymbolKindTable methods visible --- dagrt/data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dagrt/data.py b/dagrt/data.py index 436ae6a..9af9a28 100644 --- a/dagrt/data.py +++ b/dagrt/data.py @@ -145,6 +145,9 @@ class SymbolKindTable(ImmutableRecord): a nested mapping ``[phase_name][symbol_name]`` to :class:`SymbolKind` instances + + .. automethod:: get + .. automethod:: set """ def __init__(self, global_table=None, per_phase_table=None): -- GitLab From 3ee0d1578922df188b8393ccb3a55498d5f6652f Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Fri, 7 Feb 2020 22:25:16 -0600 Subject: [PATCH 6/9] Fix kind inference for arrays with complex assignments (closes #36) Depends on !41 --- dagrt/data.py | 6 +++--- test/test_codegen.py | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/dagrt/data.py b/dagrt/data.py index 9af9a28..19c4bcc 100644 --- a/dagrt/data.py +++ b/dagrt/data.py @@ -528,15 +528,15 @@ class SymbolKindFinder(object): for ident, _, _ in stmt.loops: result = result.set(phase_name, ident, kind=Integer()) - if stmt.assignee_subscript: - continue - try: kind = kim(stmt.expression) except UnableToInferKind: stmt_queue_push_buffer.append((phase_name, stmt)) else: made_progress = True + if stmt.assignee_subscript is not None: + # Subscripted assignment => assigning to array + kind = unify(Array(is_real_valued=True), kind) result = result.set(phase_name, stmt.assignee, kind=kind) elif isinstance(stmt, lang.AssignFunctionCall): diff --git a/test/test_codegen.py b/test/test_codegen.py index 98d5b33..69ce20c 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -154,6 +154,31 @@ def test_KeyToUniqueNameMap(): assert map_with_prefix.get_or_make_name_for_key('a') == 'prefixa' +def test_array_kind_inference(): + from dagrt.language import CodeBuilder, DAGCode + from pymbolic import var + x = var("x") + y = var("y") + + with CodeBuilder("main") as cb: + cb(x, "array(1)") + cb(y, "array(1)") + cb(x[0], 1 + 1j) + cb(y[0], 1) + + code = DAGCode.from_phases_list([cb.as_execution_phase("main")], "main") + from dagrt.data import infer_kinds + code = infer_kinds(code) + + from dagrt.data import Array + + x_kind = code.sym_kind_table.per_phase_table["main"]["x"] + y_kind = code.sym_kind_table.per_phase_table["main"]["y"] + + assert x_kind == Array(is_real_valued=False), x_kind + assert y_kind == Array(is_real_valued=True), y_kind + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab From deff199ea13c90c37b6cdfd63d90cda9881e6366 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Fri, 7 Feb 2020 22:30:22 -0600 Subject: [PATCH 7/9] Fix condition --- dagrt/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagrt/data.py b/dagrt/data.py index 19c4bcc..70a5265 100644 --- a/dagrt/data.py +++ b/dagrt/data.py @@ -534,7 +534,7 @@ class SymbolKindFinder(object): stmt_queue_push_buffer.append((phase_name, stmt)) else: made_progress = True - if stmt.assignee_subscript is not None: + if stmt.assignee_subscript: # Subscripted assignment => assigning to array kind = unify(Array(is_real_valued=True), kind) result = result.set(phase_name, stmt.assignee, kind=kind) -- GitLab From 551209ea5693f2fe3b2aff902f9d31f855df1120 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Sun, 9 Feb 2020 16:48:13 -0600 Subject: [PATCH 8/9] Add some simple tests for infer_kinds() --- test/test_kind_inference.py | 113 ++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100755 test/test_kind_inference.py diff --git a/test/test_kind_inference.py b/test/test_kind_inference.py new file mode 100755 index 0000000..4b753e4 --- /dev/null +++ b/test/test_kind_inference.py @@ -0,0 +1,113 @@ +#! /usr/bin/env python +from __future__ import division, with_statement + +import sys + +from dagrt.expression import parse +from dagrt.language import CodeBuilder +from dagrt.data import infer_kinds, Scalar, Boolean, Integer, Array, UserType +from dagrt.function_registry import base_function_registry, register_ode_rhs + +from utils import create_DAGCode_with_steady_phase + + +__copyright__ = "Copyright (C) 2020 Matt Wala" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +def test_kind_inference_scalar(): + with CodeBuilder("main") as cb: + cb("x", 1.) + cb("y", 1.j) + + code = create_DAGCode_with_steady_phase(cb.statements) + code = infer_kinds(code) + + kinds = code.sym_kind_table.per_phase_table["main"] + + assert kinds["x"] == Scalar(is_real_valued=True) + assert kinds["y"] == Scalar(is_real_valued=False) + + +def test_kind_inference_boolean(): + with CodeBuilder("main") as cb: + cb("flag1", parse("1 > 2")) + cb("flag2", parse("not flag1")) + cb("flag3", parse("flag1 and flag2")) + cb("flag4", parse("flag1 or flag2")) + + code = create_DAGCode_with_steady_phase(cb.statements) + code = infer_kinds(code) + + kinds = code.sym_kind_table.per_phase_table["main"] + + assert kinds["flag1"] == Boolean() + assert kinds["flag2"] == Boolean() + assert kinds["flag3"] == Boolean() + assert kinds["flag4"] == Boolean() + + +def test_kind_inference_integer(): + with CodeBuilder("main") as cb: + cb("arr", "array(10)") + cb("arr[i]", 1., [("i", 0, 10)]) + + code = create_DAGCode_with_steady_phase(cb.statements) + code = infer_kinds(code) + + kinds = code.sym_kind_table.per_phase_table["main"] + assert kinds["i"] == Integer() + + +def test_kind_inference_array(): + with CodeBuilder("main") as cb: + cb("arr", "array(10)") + cb("arr", 1) + cb("arr_complex", "array(10)") + cb("arr_complex", 1j) + + code = create_DAGCode_with_steady_phase(cb.statements) + code = infer_kinds(code) + + kinds = code.sym_kind_table.per_phase_table["main"] + assert kinds["arr"] == Array(is_real_valued=True) + assert kinds["arr_complex"] == Array(is_real_valued=False) + + +def test_kind_inference_user_type(): + with CodeBuilder("main") as cb: + cb("y", parse("y +
* f(, y)")) + + code = create_DAGCode_with_steady_phase(cb.statements) + freg = register_ode_rhs(base_function_registry, "y", identifier="f") + code = infer_kinds(code, freg) + + kinds = code.sym_kind_table.global_table + assert kinds["y"] == UserType("y") + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) -- GitLab From 6a3e54e9eea82a40ef019d6ee5f3b36672183f25 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Sun, 9 Feb 2020 16:51:11 -0600 Subject: [PATCH 9/9] Move subscripts test to test_kind_inference.py --- test/test_codegen.py | 25 ------------------------- test/test_kind_inference.py | 21 +++++++++++++++++++++ 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/test/test_codegen.py b/test/test_codegen.py index 69ce20c..98d5b33 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -154,31 +154,6 @@ def test_KeyToUniqueNameMap(): assert map_with_prefix.get_or_make_name_for_key('a') == 'prefixa' -def test_array_kind_inference(): - from dagrt.language import CodeBuilder, DAGCode - from pymbolic import var - x = var("x") - y = var("y") - - with CodeBuilder("main") as cb: - cb(x, "array(1)") - cb(y, "array(1)") - cb(x[0], 1 + 1j) - cb(y[0], 1) - - code = DAGCode.from_phases_list([cb.as_execution_phase("main")], "main") - from dagrt.data import infer_kinds - code = infer_kinds(code) - - from dagrt.data import Array - - x_kind = code.sym_kind_table.per_phase_table["main"]["x"] - y_kind = code.sym_kind_table.per_phase_table["main"]["y"] - - assert x_kind == Array(is_real_valued=False), x_kind - assert y_kind == Array(is_real_valued=True), y_kind - - if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) diff --git a/test/test_kind_inference.py b/test/test_kind_inference.py index 4b753e4..a6a9634 100755 --- a/test/test_kind_inference.py +++ b/test/test_kind_inference.py @@ -3,6 +3,8 @@ from __future__ import division, with_statement import sys +from pymbolic import var + from dagrt.expression import parse from dagrt.language import CodeBuilder from dagrt.data import infer_kinds, Scalar, Boolean, Integer, Array, UserType @@ -93,6 +95,25 @@ def test_kind_inference_array(): assert kinds["arr_complex"] == Array(is_real_valued=False) +def test_kind_inference_array_subscripts(): + x = var("x") + y = var("y") + + with CodeBuilder("main") as cb: + cb(x, "array(1)") + cb(y, "array(1)") + cb(x[0], 1 + 1j) + cb(y[0], 1) + + code = create_DAGCode_with_steady_phase(cb.statements) + code = infer_kinds(code) + + kinds = code.sym_kind_table.per_phase_table["main"] + + assert kinds["x"] == Array(is_real_valued=False) + assert kinds["y"] == Array(is_real_valued=True) + + def test_kind_inference_user_type(): with CodeBuilder("main") as cb: cb("y", parse("y +
* f(, y)")) -- GitLab