From 541978651f12cd6a943293a6f8f86cf4ebce377c Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sat, 11 Aug 2018 05:36:38 +0530
Subject: [PATCH] small changes in tests to pass test_diff

---
 loopy/transform/data.py |  1 +
 loopy/transform/diff.py | 12 ++++--------
 test/test_diff.py       |  3 ++-
 3 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/loopy/transform/data.py b/loopy/transform/data.py
index 9534279d4..5f4f2f2a7 100644
--- a/loopy/transform/data.py
+++ b/loopy/transform/data.py
@@ -486,6 +486,7 @@ set_array_dim_names = (MovedFunctionDeprecationWrapper(
 
 # {{{ remove_unused_arguments
 
+@iterate_over_kernels_if_given_program
 def remove_unused_arguments(knl):
     new_args = []
 
diff --git a/loopy/transform/diff.py b/loopy/transform/diff.py
index d0edcfd78..54d06605a 100644
--- a/loopy/transform/diff.py
+++ b/loopy/transform/diff.py
@@ -33,6 +33,7 @@ import loopy as lp
 from loopy.symbolic import RuleAwareIdentityMapper, SubstitutionRuleMappingContext
 from loopy.isl_helpers import make_slab
 from loopy.diagnostic import LoopyError
+from loopy.kernel import LoopKernel
 
 
 # {{{ diff mapper
@@ -370,6 +371,8 @@ def diff_kernel(knl, diff_outputs, by, diff_iname_prefix="diff_i",
         *diff_context.by_name*, or *None* if no dependency exists.
     """
 
+    assert isinstance(knl, LoopKernel)
+
     from loopy.kernel.creation import apply_single_writer_depencency_heuristic
     knl = apply_single_writer_depencency_heuristic(knl, warn_if_used=True)
 
@@ -398,14 +401,7 @@ def diff_kernel(knl, diff_outputs, by, diff_iname_prefix="diff_i",
 
     # }}}
 
-    # Differentiation lead to addition of new functions to the kernel.
-    # For example differentiating `sin(x)` -> `cos(x)`. Hence we would need to
-    # scope `cos(x)`.
-    from loopy.kernel.creation import scope_functions
-    differentiated_scoped_kernel = scope_functions(
-            diff_context.get_new_kernel())
-
-    return differentiated_scoped_kernel, result
+    return diff_context.get_new_kernel(), result
 
 # }}}
 
diff --git a/test/test_diff.py b/test/test_diff.py
index b735ab17a..a7fd92987 100644
--- a/test/test_diff.py
+++ b/test/test_diff.py
@@ -55,7 +55,7 @@ def test_diff(ctx_factory):
     ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)
 
-    knl = lp.make_kernel(
+    knl = lp.make_kernel_function(
          """{ [i,j]: 0<=i,j<n }""",
          """
          <> a = 1/(1+sinh(x[i] + y[j])**2)
@@ -66,6 +66,7 @@ def test_diff(ctx_factory):
 
     from loopy.transform.diff import diff_kernel
     dknl, diff_map = diff_kernel(knl, "z", "x")
+    dknl = lp.make_program_from_kernel(dknl)
     dknl = lp.remove_unused_arguments(dknl)
 
     dknl = lp.add_inames_to_insn(dknl, "diff_i0", "writes:a_dx or writes:a")
-- 
GitLab