diff --git a/examples/dagrt_fusion/fusion-study.py b/examples/dagrt_fusion/fusion-study.py
index 4f301c836c47cec69406a69e600a9d55fc3d864e..8dea5d27888aa03d75bbe1a3e7452028724be41d 100755
--- a/examples/dagrt_fusion/fusion-study.py
+++ b/examples/dagrt_fusion/fusion-study.py
@@ -168,11 +168,11 @@ def transcribe_phase(dag, field_var_name, field_components, phase_name,
                 raise NotImplementedError("lhs of statement %s is not a variable: %s"
                         % (stmt.id, stmt.lhs))
             ctx[stmt.lhs.name] = sym.cse(
-                DagrtToGrudgeRewriter(ctx)(stmt.rhs),
-                (
-                    stmt.lhs.name
-                    .replace("<", "")
-                    .replace(">", "")))
+                    DagrtToGrudgeRewriter(ctx)(stmt.rhs),
+                    (
+                        stmt.lhs.name
+                        .replace("<", "")
+                        .replace(">", "")))
 
         elif isinstance(stmt, lang.AssignFunctionCall):
             if stmt.function_id != rhs_name:
@@ -200,8 +200,11 @@ def transcribe_phase(dag, field_var_name, field_components, phase_name,
         elif isinstance(stmt, lang.YieldState):
             d2g = DagrtToGrudgeRewriter(ctx)
             yielded_states.append(
-                (stmt.time_id, d2g(stmt.time), stmt.component_id,
-                    d2g(stmt.expression)))
+                    (
+                        stmt.time_id,
+                        d2g(stmt.time),
+                        stmt.component_id,
+                        d2g(stmt.expression)))
 
         else:
             raise NotImplementedError("statement %s is of unsupported type ''%s'"
@@ -627,6 +630,32 @@ class ExecutionMapperWithMemOpCounting(ExecutionMapper):
 
 # {{{ mem op counter check
 
+def test_assignment_memory_model(ctx_factory):
+    cl_ctx = ctx_factory()
+    queue = cl.CommandQueue(cl_ctx)
+
+    _, discr = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=3)
+
+    # Assignment instruction
+    bound_op = bind(
+            discr,
+            sym.Variable("input0", sym.DD_VOLUME)
+            + sym.Variable("input1", sym.DD_VOLUME),
+            exec_mapper_factory=ExecutionMapperWithMemOpCounting)
+
+    input0 = discr.zeros(queue)
+    input1 = discr.zeros(queue)
+
+    result, profile_data = bound_op(
+            queue,
+            profile_data={},
+            input0=input0,
+            input1=input1)
+
+    assert profile_data["bytes_read"] == input0.nbytes + input1.nbytes
+    assert profile_data["bytes_written"] == result.nbytes
+
+
 @pytest.mark.parametrize("use_fusion", (True, False))
 def test_stepper_mem_ops(ctx_factory, use_fusion):
     cl_ctx = ctx_factory()
@@ -904,14 +933,14 @@ def problem_stats(order=3):
     _, dg_discr_2d = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=order)
     print("Number of 2D elements:", dg_discr_2d.mesh.nelements)
     vol_discr_2d = dg_discr_2d.discr_from_dd("vol")
-    dofs_2d = set(group.nunit_nodes for group in vol_discr_2d.groups)
+    dofs_2d = {group.nunit_nodes for group in vol_discr_2d.groups}
     from pytools import one
     print("Number of DOFs per 2D element:", one(dofs_2d))
 
     _, dg_discr_3d = get_strong_wave_op_with_discr(cl_ctx, dims=3, order=order)
     print("Number of 3D elements:", dg_discr_3d.mesh.nelements)
     vol_discr_3d = dg_discr_3d.discr_from_dd("vol")
-    dofs_3d = set(group.nunit_nodes for group in vol_discr_3d.groups)
+    dofs_3d = {group.nunit_nodes for group in vol_discr_3d.groups}
     from pytools import one
     print("Number of DOFs per 3D element:", one(dofs_3d))