diff --git a/examples/dagrt_fusion/fusion-study.py b/examples/dagrt_fusion/fusion-study.py index d37dca94ad93e539dd38609427c3f4798b613b2a..ce12f1691a5fbfa5c9dff7228a5d61784b1321f9 100644 --- a/examples/dagrt_fusion/fusion-study.py +++ b/examples/dagrt_fusion/fusion-study.py @@ -90,8 +90,7 @@ class GrudgeArgSubstitutor(gmap.SymbolicEvaluator): def map_grudge_variable(self, expr): if expr.name in self.args: return self.args[expr.name] - else: - return super().map_variable(expr) + return super().map_variable(expr) def transcribe_phase(dag, field_var_name, field_components, phase_name, @@ -337,66 +336,6 @@ class FusedRK4TimeStepper(RK4TimeStepperBase): yield (t, self.component_getter(output_states)) -class FusedGrudgeRK4TimeStepper(object): - - def __init__(self, queue, field_var_name, dt, fields, sym_rhs, discr, - component_getter, t_start=0): - self.t_start = t_start - self.dt = dt - dt_method = LSRK4Method(component_id=field_var_name) - dt_code = dt_method.generate() - - from pytools.obj_array import join_fields - - # Flatten fields. - flattened_fields = [] - for field in fields: - if isinstance(field, list): - flattened_fields.extend(field) - else: - flattened_fields.append(field) - flattened_fields = join_fields(*flattened_fields) - del fields - - output_vars, results, yielded_states = transcribe_phase( - dt_code, field_var_name, len(flattened_fields), - "primary", sym_rhs) - - output_t = results[0] - output_dt = results[1] - output_states = results[2] - output_residuals = results[3] - - assert len(output_states) == len(flattened_fields) - assert len(output_states) == len(output_residuals) - - flattened_results = join_fields(output_t, output_dt, *output_states) - self.bound_op = bind(discr, flattened_results) - self.queue = queue - - self.initial_context = { - "input_t": t_start, - "input_dt": dt, - self.state_name: flattened_fields, - "input_residual": flattened_fields, - } - - self.component_getter = component_getter - - def run(self, t_end): - t = self.t_start - context = self.initial_context.copy() - - while t <= t_end: - results = self.bound_op(self.queue, **context) - t = results[0] - context["input_t"] = t - context["input_dt"] = results[1] - output_states = results[2:] - context[self.state_name] = output_states - yield (t, self.component_getter(output_states)) - - def get_strong_wave_component(state_component): return (state_component[0], state_component[1:])