diff --git a/examples/dagrt_fusion/fusion-study.py b/examples/dagrt_fusion/fusion-study.py index ab33eb9105c0cc4d441ccf597594fb86a6c34051..f130c2486cdc271f5935da2aa79a8072b9b45644 100644 --- a/examples/dagrt_fusion/fusion-study.py +++ b/examples/dagrt_fusion/fusion-study.py @@ -212,6 +212,10 @@ def transcribe_phase(dag, field_var_name, field_components, phase_name, class RK4TimeStepperBase(object): + def __init__(self, queue, component_getter): + self.queue = queue + self.component_getter = component_getter + def get_initial_context(self, fields, t_start, dt): from pytools.obj_array import join_fields @@ -308,6 +312,8 @@ class RK4TimeStepper(RK4TimeStepperBase): its components """ + super().__init__(queue, component_getter) + from pymbolic import var # Construct sym_rhs to have the effect of replacing the RHS calls in the @@ -351,10 +357,9 @@ class FusedRK4TimeStepper(RK4TimeStepperBase): def __init__(self, queue, discr, field_var_name, sym_rhs, num_fields, component_getter, exec_mapper_factory=ExecutionMapper): - self.queue = queue + super().__init__(queue, component_getter) self.set_up_stepper( discr, field_var_name, sym_rhs, num_fields, exec_mapper_factory) - self.component_getter = component_getter # }}} @@ -605,7 +610,7 @@ class MemOpCountingExecutionMapper(ExecutionMapper): # {{{ mem op counter check @pytest.mark.parametrize("use_fusion", (True, False)) -def test_stepper_mem_ops(use_fusion=True): +def test_stepper_mem_ops(use_fusion): cl_ctx = cl.create_some_context() queue = cl.CommandQueue(cl_ctx)