diff --git a/leap/multistep/multirate/__init__.py b/leap/multistep/multirate/__init__.py index 7cb94cbba1f83aec94e64d39103fe3c1702534fc..6817380daef4f4654e382a2dbc9450e5e13b7636 100644 --- a/leap/multistep/multirate/__init__.py +++ b/leap/multistep/multirate/__init__.py @@ -358,27 +358,18 @@ class MultiRateMultiStepMethod(Method): # {{{ process intervals - intervals = sorted(rhs.interval + intervals = tuple(rhs.interval for component_rhss in self.rhss for rhs in component_rhss) - substep_counts = [] - for i in range(1, len(intervals)): - last_interval = intervals[i-1] - interval = intervals[i] + interval_gcd = gcd(intervals) + if interval_gcd != 1: + raise ValueError( + "integration intervals must be relatively prime: " + "found intervals %s with common factor %d" + % (intervals, interval_gcd)) - if interval % last_interval != 0: - raise ValueError( - "intervals are not integer multiples of each other: " - + ", ".join(str(intv) for intv in intervals)) - - substep_counts.append(interval // last_interval) - - if min(intervals) != 1: - raise ValueError("the smallest interval is not 1") - - self.intervals = intervals - self.substep_counts = substep_counts + self.nsubsteps = lcm(intervals) # }}} @@ -417,10 +408,6 @@ class MultiRateMultiStepMethod(Method): # }}} - @property - def nsubsteps(self): - return max(self.intervals) - def emit_initialization(self, cb): """Initialize method variables.""" @@ -1398,4 +1385,30 @@ class TextualSchemeExplainer(SchemeExplainerBase): # }}} + +# {{{ utils + +def _gcd(a, b): + while a: + b, a = a, b % a + return b + + +def gcd(args): + args = iter(args) + result = next(args) + for arg in args: + result = _gcd(result, arg) + return result + + +def lcm(args): + args = iter(args) + result = next(args) + for arg in args: + result = (result * arg) // _gcd(result, arg) + return result + +# }}} + # vim: foldmethod=marker diff --git a/test/test_multirate.py b/test/test_multirate.py index 4d8f53bef5f79d07817280436e51dcf895eff2d1..5eed940262687138b5baa4ed51bfcadd139fb638 100644 --- a/test/test_multirate.py +++ b/test/test_multirate.py @@ -391,6 +391,101 @@ def test_dot(order=3, step_ratio=3, method_name="F", show=False): show_dependency_graph(code) +@pytest.mark.parametrize( + "fast_interval, slow_interval", + ( + (1, 2), + (3, 4))) +def test_two_rate_intervals(fast_interval, slow_interval, order=3): + # Solve + # f' = f+s + # s' = -f+s + + def true_f(t): + return np.exp(t)*np.sin(t) + + def true_s(t): + return np.exp(t)*np.cos(t) + + method = MultiRateMultiStepMethod( + order, + ( + ( + "dt", "fast", "=", + MRHistory(fast_interval, "f", ("fast", "slow",)), + ), + ( + "dt", "slow", "=", + MRHistory(slow_interval, "s", ("fast", "slow")) + ), + ), + static_dt=True) + + code = method.generate() + print(code) + + from pytools.convergence import EOCRecorder + eocrec = EOCRecorder() + + from dagrt.codegen import PythonCodeGenerator + codegen = PythonCodeGenerator(class_name='Method') + + stepper_cls = codegen.get_class(code) + + for n in range(4, 7): + t = 0 + dt = fast_interval * slow_interval * 2**(-n) + final_t = 10 + + stepper = stepper_cls( + function_map={ + "f": lambda t, fast, slow: fast + slow, + "s": lambda t, fast, slow: -fast + slow, + }) + + stepper.set_up( + t_start=t, dt_start=dt, + context={ + "fast": true_f(t), + "slow": true_s(t), + }) + + f_times = [] + f_values = [] + s_times = [] + s_values = [] + for event in stepper.run(t_end=final_t): + if isinstance(event, stepper_cls.StateComputed): + if event.component_id == "fast": + f_times.append(event.t) + f_values.append(event.state_component) + elif event.component_id == "slow": + s_times.append(event.t) + s_values.append(event.state_component) + else: + assert False, event.component_id + + f_times = np.array(f_times) + s_times = np.array(s_times) + f_values_true = true_f(f_times) + s_values_true = true_s(s_times) + + f_err = f_values - f_values_true + s_err = s_values - s_values_true + + error = ( + la.norm(f_err) / la.norm(f_values_true) + + # noqa: W504 + la.norm(s_err) / la.norm(s_values_true)) + + eocrec.add_data_point(dt, error) + + print(eocrec.pretty_print()) + + orderest = eocrec.estimate_order_of_convergence()[0, 1] + assert orderest > order * 0.7 + + def test_dependent_state(order=3, step_ratio=3): # Solve # f' = f+s