diff --git a/dagrt/codegen/python.py b/dagrt/codegen/python.py index 16e89b7fc62b16d971648a8beabd3056ffaeb3a3..4fc6f3c0ff28e913f16ec3d202c0fb62e46cd619 100644 --- a/dagrt/codegen/python.py +++ b/dagrt/codegen/python.py @@ -35,85 +35,26 @@ import six def pad_python(line, width): - line += ' ' * (width - 1 - len(line)) - line += '\\' + line += " " * (width - 1 - len(line)) + line += "\\" return line wrap_line = partial(wrap_line_base, pad_func=pad_python) -_inner_class_code = '''from collections import namedtuple - -class StateComputed(namedtuple("StateComputed", - ["t", "time_id", "component_id", "state_component"])): - """ - .. attribute:: t - .. attribute:: time_id - .. attribute:: component_id - - Identifier of the state component being returned. - - .. attribute:: state_component - """ - -class StepCompleted( - namedtuple("StepCompleted", - ["dt", "t", "current_phase", "next_phase"])): - """ - .. attribute:: dt - - Size of next time step. - - .. attribute:: t - - Approximate integrator time at end of step. - - .. attribute:: current_phase - .. attribute:: next_phase - """ - -class StepFailed(namedtuple("StepFailed", ["t"])): - """ - .. attribute:: t - - Floating point number. - """ - - -class TimeStepUnderflow(RuntimeError): - pass - - -class FailStepException(RuntimeError): - pass - - -class TransitionEvent(Exception): - - def __init__(self, next_phase): - self.next_phase = next_phase - - -class StepError(Exception): - def __init__(self, condition, message): - self.condition = condition - self.messagew = message - - Exception.__init__(self, "%s: %s" % (condition, message)) - - +_inner_class_code = """ class _function_symbol_container(object): pass -''' +""" class PythonClassEmitter(PythonEmitter): """Emits code for a Python class.""" - def __init__(self, class_name, superclass='object'): + def __init__(self, class_name, superclass="object"): super(PythonClassEmitter, self).__init__() - self('from __future__ import division, print_function') - self('class {cls}({superclass}):'.format(cls=class_name, + self("from __future__ import division, print_function") + self("class {cls}({superclass}):".format(cls=class_name, superclass=superclass)) self.indent() @@ -130,10 +71,10 @@ class PythonNameManager(object): """ def __init__(self): - self._local_map = KeyToUniqueNameMap(forced_prefix='local') - self._global_map = KeyToUniqueNameMap(forced_prefix='self.global_', - start={'': 'self.t', - '
': 'self.dt'}) + self._local_map = KeyToUniqueNameMap(forced_prefix="local") + self._global_map = KeyToUniqueNameMap(forced_prefix="self.global_", + start={"": "self.t", + "
": "self.dt"}) self.function_map = KeyToUniqueNameMap(forced_prefix="self._functions.") def name_global(self, name): @@ -142,7 +83,7 @@ class PythonNameManager(object): def clear_locals(self): del self._local_map - self._local_map = KeyToUniqueNameMap(forced_prefix='local') + self._local_map = KeyToUniqueNameMap(forced_prefix="local") def name_local(self, local): """Return the identifier for a local variable.""" @@ -180,6 +121,12 @@ class BareExpression(object): class CodeGenerator(StructuredCodeGenerator): + """Python code generator. + + Generates a class that follows the same interface as the interpreter (see + :mod:`dagrt.exec_numpy`). + + """ def __init__(self, class_name, class_preamble=None, function_registry=None): """ @@ -203,7 +150,7 @@ class CodeGenerator(StructuredCodeGenerator): self._name_manager = PythonNameManager() self._expr_mapper = PythonExpressionMapper( - self._name_manager, function_registry, numpy='self._numpy') + self._name_manager, function_registry, numpy="self._numpy") def __call__(self, dag): from dagrt.codegen.analysis import verify_code @@ -265,36 +212,36 @@ class CodeGenerator(StructuredCodeGenerator): for line in _inner_class_code.splitlines(): emit(line) - from inspect import getsourcefile - import dagrt.builtins_python as builtins - builtins_source_file = getsourcefile(builtins) + from inspect import getsourcelines - if builtins_source_file is None: - raise RuntimeError( - "source code for built-in functions cannot be located") - - with open(builtins_source_file) as srcf: - builtins_source = srcf.read() + import dagrt.events_python as events + events_source_lines, _ = getsourcelines(events) + for line in events_source_lines: + line = line.rstrip("\r\n") + emit(line) - for l in builtins_source.split("\n"): - if l.startswith("def builtin"): + import dagrt.builtins_python as builtins + builtins_source_lines, _ = getsourcelines(builtins) + for line in builtins_source_lines: + line = line.rstrip("\r\n") + if line.startswith("def builtin"): emit("@staticmethod") - emit(l.replace("builtin", "_builtin")) + emit(line.replace("builtin", "_builtin")) self._class_emitter.incorporate(emit) def _emit_constructor(self, dag): """Emit the constructor.""" - emit = PythonFunctionEmitter('__init__', ('self', 'function_map')) + emit = PythonFunctionEmitter("__init__", ("self", "function_map")) # Perform necessary imports. - emit('import numpy') - emit('self._numpy = numpy') + emit("import numpy") + emit("self._numpy = numpy") # Make function symbols available - emit('self._functions = self._function_symbol_container()') + emit("self._functions = self._function_symbol_container()") for function_id in self._name_manager.function_map: py_function_id = self._name_manager.name_function(function_id) - emit('{py_function_id} = function_map["{function_id}"]' + emit("{py_function_id} = function_map[\"{function_id}\"]" .format( py_function_id=py_function_id, function_id=function_id)) @@ -310,17 +257,17 @@ class CodeGenerator(StructuredCodeGenerator): def _emit_set_up(self, dag): """Emit the set_up() method.""" - emit = PythonFunctionEmitter('set_up', - ('self', 't_start', 'dt_start', 'context')) - emit('self.t = t_start') - emit('self.dt = dt_start') + emit = PythonFunctionEmitter("set_up", + ("self", "t_start", "dt_start", "context")) + emit("self.t = t_start") + emit("self.dt = dt_start") # Save all the context components. for component_id in self._name_manager.get_global_ids(): component = self._name_manager.name_global(component_id) - if not component_id.startswith(''): + if not component_id.startswith(""): continue component_id = component_id[7:] - emit('{component} = context.get("{component_id}")'.format( + emit("{component} = context.get(\"{component_id}\")".format( component=component, component_id=component_id)) emit("self.next_phase = "+repr(dag.initial_phase)) @@ -329,7 +276,8 @@ class CodeGenerator(StructuredCodeGenerator): self._class_emitter.incorporate(emit) def _emit_run(self): - emit = PythonFunctionEmitter('run', ('self', 't_end=None', 'max_steps=None')) + args = ("self", "t_end=None", "max_steps=None", "dt_next=None") + emit = PythonFunctionEmitter("run", args) emit(""" n_steps = 0 while True: @@ -339,20 +287,23 @@ class CodeGenerator(StructuredCodeGenerator): if max_steps is not None and n_steps >= max_steps: return - cur_phase = self.next_phase + cur_state = self.next_phase try: - for evt in self.run_single_step(): + for evt in self.run_single_step(dt_next): yield evt except self.FailStepException: - yield self.StepFailed(t=self.t) + dt_next = (yield self.StepFailed(t=self.t, dt=self.dt)) continue except self.TransitionEvent as evt: self.next_phase = evt.next_phase - yield self.StepCompleted(dt=self.dt, t=self.t, - current_phase=cur_phase, next_phase=self.next_phase) + dt_next = (yield self.StepCompleted( + dt=self.dt, + t=self.t, + current_phase=cur_state, + next_phase=self.next_phase)) n_steps += 1 """) @@ -360,9 +311,13 @@ class CodeGenerator(StructuredCodeGenerator): self._class_emitter.incorporate(emit) def _emit_run_single_step(self): - emit = PythonFunctionEmitter('run_single_step', ('self',)) + args = ("self", "dt_next=None") + emit = PythonFunctionEmitter("run_single_step", args) emit(""" + if dt_next is not None: + self.dt = dt_next + self.next_phase, phase_func = ( self.phase_transition_table[self.next_phase]) @@ -381,7 +336,7 @@ class CodeGenerator(StructuredCodeGenerator): return self._class_emitter.get() def emit_def_begin(self, name): - self._emitter = PythonFunctionEmitter('phase_' + name, ('self',)) + self._emitter = PythonFunctionEmitter("phase_" + name, ("self",)) self._name_manager.clear_locals() def emit_def_end(self): @@ -390,7 +345,7 @@ class CodeGenerator(StructuredCodeGenerator): del self._emitter def emit_if_begin(self, expr): - self._emit('if {expr}:'.format(expr=self._expr(expr))) + self._emit("if {expr}:".format(expr=self._expr(expr))) self._emitter.indent() def emit_if_end(self): @@ -398,11 +353,11 @@ class CodeGenerator(StructuredCodeGenerator): def emit_else_begin(self): self._emitter.dedent() - self._emit('else:') + self._emit("else:") self._emitter.indent() def emit_return(self): - self._emit('return') + self._emit("return") # Ensure that Python recognizes this method as a generator function by # adding a yield statement. Otherwise, calling methods that do not # yield any values may result in raising a naked StopIteration instead @@ -412,7 +367,7 @@ class CodeGenerator(StructuredCodeGenerator): # TODO: Python 3.3+ has "yield from ()" which results in slightly less # awkward syntax. if not self._has_yield_inst: - self._emit('yield') + self._emit("yield") # {{{ statements @@ -436,7 +391,7 @@ class CodeGenerator(StructuredCodeGenerator): subscript_code = "" self._emit( - '{name}{sub} = {expr}' + "{name}{sub} = {expr}" .format( name=self._name_manager[inst.assignee], sub=subscript_code, @@ -459,7 +414,7 @@ class CodeGenerator(StructuredCodeGenerator): from pymbolic import var self._emit( - '{assign_code}{expr}' + "{assign_code}{expr}" .format( assign_code=assign_code, expr=self._expr_mapper.map_generic_call( @@ -468,30 +423,30 @@ class CodeGenerator(StructuredCodeGenerator): inst.kw_parameters))) def emit_inst_YieldState(self, inst): - self._emit('yield self.StateComputed(t={t}, time_id={time_id}, ' - 'component_id={component_id}, ' - 'state_component={state_component})'.format( + self._emit("yield self.StateComputed(t={t}, time_id={time_id}, " + "component_id={component_id}, " + "state_component={state_component})".format( t=self._expr(inst.time), time_id=repr(inst.time_id), component_id=repr(inst.component_id), state_component=self._expr(inst.expression))) def emit_inst_Raise(self, inst): - self._emit('raise self.StepError({condition}, {message})'.format( + self._emit("raise self.StepError({condition}, {message})".format( condition=repr(inst.error_condition.__name__), message=repr(inst.error_message))) if not self._has_yield_inst: - self._emit('yield') + self._emit("yield") def emit_inst_FailStep(self, inst): - self._emit('raise self.FailStepException()') + self._emit("raise self.FailStepException()") if not self._has_yield_inst: - self._emit('yield') + self._emit("yield") def emit_inst_SwitchPhase(self, inst): - assert '\'' not in inst.next_phase - self._emit('raise self.TransitionEvent(\'' + inst.next_phase + '\')') + assert "\"" not in inst.next_phase + self._emit("raise self.TransitionEvent(\"" + inst.next_phase + "\")") if not self._has_yield_inst: - self._emit('yield') + self._emit("yield") # }}} diff --git a/dagrt/events_python.py b/dagrt/events_python.py new file mode 100644 index 0000000000000000000000000000000000000000..fae312d1c5e3b081c8d035dc0a4229cc55ea4f96 --- /dev/null +++ b/dagrt/events_python.py @@ -0,0 +1,133 @@ +"""Python runtime exceptions and events""" + +__copyright__ = "Copyright (C) 2019 Matt Wala, Andreas Kloeckner" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from collections import namedtuple + + +# {{{ exceptions + +class FailStepException(Exception): + """Raised by a :class:`dagrt.language.FailStep` statement. + """ + pass + + +class TransitionEvent(Exception): + """Raised by a :class:`dagrt.language.SwitchPhase` statement. + + .. attribute:: next_phase + + The name of the next phase to execute. + """ + + def __init__(self, next_phase): + self.next_phase = next_phase + + +class StepError(Exception): + """Raised by a :class:`dagrt.language.Raise` statement. + + .. attribute:: condition + + A string indicating the type of error condition, obtained from the + `__name__` attribute of the condition argument to the Raise statement. + + .. attribute:: message + + A string with a detailed error message. + """ + + def __init__(self, condition, message): + self.condition = condition + self.message = message + Exception.__init__(self, "%s: %s" % (condition, message)) + +# }}} + + +# {{{ events + +class StateComputed(namedtuple("StateComputed", + ["t", "time_id", "component_id", "state_component"])): + """Returns the value of a state component. + + .. attribute:: t + + Time associated with the component being returned. + + .. attribute:: time_id + + An optional string describing the significance of the time :data:`t`. + + .. attribute:: component_id + + Identifier of the state component being returned. + + .. attribute:: state_component + + The value of the state component. + """ + + +class StepCompleted( + namedtuple("StepCompleted", + ["dt", "t", "current_phase", "next_phase"])): + """Indicates a step completed execution without error. + + .. attribute:: dt + + Value of `
` at the end of the step. + + .. attribute:: t + + Value of `` at the end of the step. + + .. attribute:: current_phase + + Name of the current phase. + + .. attribute:: next_phase + + Name of the next phase. + """ + + +class StepFailed(namedtuple("StepFailed", ["t", "dt"])): + """Indicates a step failed to finish executing with a non-fatal error. + + .. attribute:: t + + Value of `` at the end of execution. + + .. attribute:: dt + + Value of `
` at the end of execution. This may not be the same as + the value of `
` that was used when the step failed, because the + integrator may have changed the value of `
` after failing to + advance to the next step. + """ + +# }}} + +# vim: fdm=marker diff --git a/dagrt/exec_numpy.py b/dagrt/exec_numpy.py index 20fba9b1f2b624e601d5cb3008a211c355cf1fb7..ca2ae5a9cb0398d28d9e562b5dbaf514568f2072 100644 --- a/dagrt/exec_numpy.py +++ b/dagrt/exec_numpy.py @@ -22,60 +22,51 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from collections import namedtuple +from dagrt.events_python import ( + FailStepException, TransitionEvent, StepError, + StateComputed, StepCompleted, StepFailed) + from dagrt.expression import EvaluationMapper + import six -class FailStepException(Exception): - pass +__doc__ = """ +Interpreter class +----------------- -class TransitionEvent(Exception): +.. autoclass:: NumpyInterpreter - def __init__(self, next_phase): - self.next_phase = next_phase +.. _interpreter_events: -# {{{ events returned from NumpyInterpreter.run() +Events returned during interpretation +------------------------------------- -class StateComputed(namedtuple("StateComputed", - ["t", "time_id", "component_id", "state_component"])): - """ - .. attribute:: t - .. attribute:: time_id - .. attribute:: component_id +.. autoclass:: StateComputed - Identifier of the state component being returned. +.. autoclass:: StepCompleted - .. attribute:: state_component - """ +.. autoclass:: StepFailed -class StepCompleted( - namedtuple("StepCompleted", - ["dt", "t", "current_state", "next_phase"])): - """ - .. attribute:: dt +.. _interpreter_exceptions: - Size of next time step. +Exceptions raised during interpretation +--------------------------------------- - .. attribute:: t +.. autoclass:: StepError - Approximate integrator time at end of step. - .. attribute:: current_state - .. attribute:: next_phase - """ +Low-level exceptions +-------------------- +.. autoclass:: TransitionEvent -class StepFailed(namedtuple("StepFailed", ["t"])): - """ - .. attribute:: t +.. autoclass:: FailStepException - Floating point number. - """ -# }}} +""" # {{{ interpreter @@ -86,18 +77,25 @@ class NumpyInterpreter(object): .. attribute:: next_phase + The name of the next phase to execute + .. automethod:: set_up .. automethod:: run .. automethod:: run_single_step """ + # These are here as class attributes for compatibility with the code + # generator interface. StateComputed = StateComputed StepCompleted = StepCompleted StepFailed = StepFailed + StepError = StepError + TransitionEvent = TransitionEvent + FailStepException = FailStepException def __init__(self, code, function_map): """ - :arg code: an instance of :class:`dagrt.DAGCode` + :arg code: an instance of :class:`dagrt.language.DAGCode` :arg function_map: a mapping from function identifiers to functions """ self.code = code @@ -117,7 +115,10 @@ class NumpyInterpreter(object): self.eval_mapper = EvaluationMapper(self.context, self.functions) def set_up(self, t_start, dt_start, context): - """ + """Initialize the time integration state. + + :arg t_start: initial value of `` + :arg dt_start: initial value of `
` :arg context: a dictionary mapping identifiers to their values """ @@ -128,8 +129,29 @@ class NumpyInterpreter(object): raise ValueError("state variables may not start with '<'") self.context[""+key] = val - def run(self, t_end=None, max_steps=None): - """Generates events.""" + def run(self, t_end=None, max_steps=None, dt_next=None): + """Execute the time integrator, generating a sequence of events. + + See :ref:`interpreter_events` and :ref:`interpreter_exceptions`. State + components are yielded using :class:`StateComputed`. The end of a step + is signalled by either a :class:`StepCompleted` or :class:`StepFailed` + event. + + :arg t_end: if not *None*, halts execution when, at the start of a + step, `` is at least this value + :arg max_steps: if not *None*, bounds the number of steps to take + :arg dt_next: if not *None*, overrides the value of `
` at the + beginning of the first step. This should only be used with + integrators that support variable step sizes. The value of `
` + may be overridden before subsequent steps using the + `send `_ + method of the generator. + + :raises StepError: when a :class:`dagrt.language.Raise` executes + + :returns: a generator yielding events for a sequence of steps + + """ # noqa n_steps = 0 while True: @@ -141,25 +163,52 @@ class NumpyInterpreter(object): cur_state = self.next_phase try: - for evt in self.run_single_step(): + for evt in self.run_single_step(dt_next): yield evt except FailStepException: - yield StepFailed(t=self.context[""]) + dt_next = (yield StepFailed( + t=self.context[""], + dt=self.context["
"])) continue except TransitionEvent as evt: self.next_phase = evt.next_phase - yield StepCompleted( + dt_next = (yield StepCompleted( dt=self.context["
"], t=self.context[""], - current_state=cur_state, - next_phase=self.next_phase) + current_phase=cur_state, + next_phase=self.next_phase)) n_steps += 1 - def run_single_step(self): + def run_single_step(self, dt_next=None): + """Low-level interface for running a single step. + + :meth:`run` provides a simpler interface. + + This returns a generator yielding a sequence of :class:`StateComputed` + events. :class:`StepCompleted` and :class:`StepFailed` are *not* + yielded to indicate the end of the step. Non-normal exit circumstances + are handled with exceptions. + + :arg dt_next: if not *None*, overrides the value of `
` at the + beginning of the step (only for use when the integrator supports + varying the step size). + + :raises StepError: when a :class:`dagrt.language.Raise` executes + :raises TransitionEvent: when a :class:`dagrt.language.SwitchPhase` + executes + :raises FailStepException: when a :class:`dagrt.language.FailStep` + executes + + :returns: a generator yielding events for a single step + + """ + if dt_next is not None: + self.context["
"] = dt_next + try: self.exec_controller.reset() cur_state = self.code.phases[self.next_phase] @@ -256,7 +305,7 @@ class NumpyInterpreter(object): self.context[assignee] = res def exec_Raise(self, stmt): - raise stmt.error_condition(stmt.error_message) + raise StepError(stmt.error_condition.__name__, stmt.error_message) def exec_FailStep(self, stmt): raise FailStepException() diff --git a/dagrt/version.py b/dagrt/version.py index 0d96d97a7541d7b87196329b1912c2f2808cd501..6f62bed21b8e655e05f84adad5649d80bda72799 100644 --- a/dagrt/version.py +++ b/dagrt/version.py @@ -1,2 +1,2 @@ -VERSION = (2019, 4) +VERSION = (2019, 5) VERSION_TEXT = ".".join(str(i) for i in VERSION) diff --git a/doc/reference.rst b/doc/reference.rst index 6d5e74daa7a967374689b57b3c3c50d46078c8c2..567f0757773c4b828892dfa8291c1fd3e6eaa53f 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -54,4 +54,3 @@ Utilities --------------------------------- .. automodule:: dagrt.exec_numpy - :members: diff --git a/examples/adaptive_rk.py b/examples/adaptive_rk.py index 27402e79d35e6eae5a8dd257b514895583cf548b..baf405bf329402ea4fbc9756af8b1a787d362998 100755 --- a/examples/adaptive_rk.py +++ b/examples/adaptive_rk.py @@ -50,10 +50,7 @@ def adaptive_rk_method(tol): dt = var("
") t = var("") dt_old = var("dt_old") - - # Helpers for expression fragments - def norm(val): - return var("norm_2")(val) + norm = var("norm_2") def dt_scaling(tol, err): # Return a suitable scaling factor for dt. @@ -79,6 +76,7 @@ def adaptive_rk_method(tol): with cb.else_(): cb(y, y_h) cb(t, t + dt_old) + cb.yield_state(y, "y", t, "update_y") return DAGCode.from_phases_list( [cb.as_execution_phase("adaptrk")], "adaptrk") @@ -101,15 +99,37 @@ def main(): tolerances = [1.0e-1, 1.0e-2, 1.0e-3, 1.0e-5] errors = [] + t_end = 10 for tol in tolerances: + # Create the method and generate code. method = adaptive_rk_method(tol) AdaptiveRK = codegen.get_class(method) solver = AdaptiveRK({"g": rhs}) solver.set_up(t_start=1.0, dt_start=0.1, context={"y": np.array([1., 3.])}) - for evt in solver.run(t_end=10.0): - final_time = evt.t - errors.append(np.abs(solver.global_state_y[0] - soln(final_time))) + + # Execute the method. + dt_next = None + final_y = final_t = None + runner = solver.run(t_end=t_end) + while True: + try: + event = runner.send(dt_next) + except StopIteration: + break + + if isinstance(event, solver.StateComputed): + final_y = event.state_component + final_t = event.t + + dt_next = None + if isinstance(event, (solver.StepCompleted, solver.StepFailed)): + if event.t + event.dt >= t_end: + # Override dt to ensure we don't overshoot t_end. + dt_next = t_end - event.t + + assert np.isclose(final_t, t_end) + errors.append(np.abs(final_y[0] - soln(t_end))) print("Tolerance\tError") print("-" * 25) diff --git a/test/test_codegen_python.py b/test/test_codegen_python.py index 57be981d5dd29ab2cba3d572739ce94768b579f4..0fcf942064264cd2b172ae5963eddb1df13d3048 100755 --- a/test/test_codegen_python.py +++ b/test/test_codegen_python.py @@ -139,7 +139,7 @@ def test_basic_assign_rhs_codegen(): assert isinstance(hist[2], method.StepCompleted) -def test_basic_raise_codegen(): +def test_basic_raise_codegen(python_method_impl): """Test code generation of the Raise statement.""" cbuild = RawCodeBuilder() @@ -151,9 +151,7 @@ def test_basic_raise_codegen(): code = create_DAGCode_with_init_and_main_phases( init_statements=[], main_statements=cbuild.statements) - codegen = PythonCodeGenerator(class_name="Method") - Method = codegen.get_class(code) # noqa - method = Method({}) + method = python_method_impl(code, function_map={}) method.set_up(t_start=0, dt_start=0, context={}) try: # initialization @@ -162,12 +160,10 @@ def test_basic_raise_codegen(): # first primary step for result in method.run_single_step(): assert False - except method.TimeStepUnderflow: - pass - except Method.StepError as e: + except method.StepError as e: assert e.condition == "TimeStepUnderflow" - except Exception as e: - assert False, e + else: + assert False def test_basic_fail_step_codegen(): @@ -455,6 +451,51 @@ def test_class_preamble(): assert events[0].state_component == 1 +def test_user_supplied_step_size(python_method_impl): + """Test supplying *
* via the *send()* method of the generator. + """ + dt_max = 1 + + with CodeBuilder(name="primary") as cb: + with cb.if_("
", ">", dt_max): + cb.fail_step() + cb("", " +
") + + code = create_DAGCode_with_steady_phase(cb.statements) + interp = python_method_impl(code, function_map={}) + interp.set_up(t_start=0, dt_start=0, context={}) + + step_sizes_to_take = [2, 0.5, 0.5, 2, 0.5] + + # Construct list of expected events. + expected_events = [] + time = 0 + for dt_next in step_sizes_to_take: + if dt_next > dt_max: + expected_events.append(("fail", time, dt_next)) + else: + time += dt_next + expected_events.append(("complete", time, dt_next)) + + # Get actual events. + events = [] + runner = interp.run(dt_next=step_sizes_to_take[0]) + for step_size in step_sizes_to_take: + event = runner.send(step_size if events else None) + assert isinstance(event, (interp.StepFailed, interp.StepCompleted)) + events.append(event) + + # Check actual events against expected. + assert len(events) == len(expected_events) + for event, expected_event in zip(events, expected_events): + if expected_event[0] == "fail": + assert isinstance(event, interp.StepFailed) + else: + assert isinstance(event, interp.StepCompleted) + assert event.t == expected_event[1] + assert event.dt == expected_event[2] + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])