diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 9534279d4050c6789d80a820370fb3586f8d8105..5f4f2f2a77b927e4a4352077ed94492249ef75a0 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -486,6 +486,7 @@ set_array_dim_names = (MovedFunctionDeprecationWrapper( # {{{ remove_unused_arguments +@iterate_over_kernels_if_given_program def remove_unused_arguments(knl): new_args = [] diff --git a/loopy/transform/diff.py b/loopy/transform/diff.py index d0edcfd7812685938fca6c12bf4c35fe47031c2e..54d06605a9ec4e65ba93a0a21d66b69bbe53bfa6 100644 --- a/loopy/transform/diff.py +++ b/loopy/transform/diff.py @@ -33,6 +33,7 @@ import loopy as lp from loopy.symbolic import RuleAwareIdentityMapper, SubstitutionRuleMappingContext from loopy.isl_helpers import make_slab from loopy.diagnostic import LoopyError +from loopy.kernel import LoopKernel # {{{ diff mapper @@ -370,6 +371,8 @@ def diff_kernel(knl, diff_outputs, by, diff_iname_prefix="diff_i", *diff_context.by_name*, or *None* if no dependency exists. """ + assert isinstance(knl, LoopKernel) + from loopy.kernel.creation import apply_single_writer_depencency_heuristic knl = apply_single_writer_depencency_heuristic(knl, warn_if_used=True) @@ -398,14 +401,7 @@ def diff_kernel(knl, diff_outputs, by, diff_iname_prefix="diff_i", # }}} - # Differentiation lead to addition of new functions to the kernel. - # For example differentiating `sin(x)` -> `cos(x)`. Hence we would need to - # scope `cos(x)`. - from loopy.kernel.creation import scope_functions - differentiated_scoped_kernel = scope_functions( - diff_context.get_new_kernel()) - - return differentiated_scoped_kernel, result + return diff_context.get_new_kernel(), result # }}} diff --git a/test/test_diff.py b/test/test_diff.py index b735ab17a716c84bfa52df7f73476b4c575cda0f..a7fd929875c70352dbb4fec90fb28fe4ddfec3a3 100644 --- a/test/test_diff.py +++ b/test/test_diff.py @@ -55,7 +55,7 @@ def test_diff(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) - knl = lp.make_kernel( + knl = lp.make_kernel_function( """{ [i,j]: 0<=i,j<n }""", """ <> a = 1/(1+sinh(x[i] + y[j])**2) @@ -66,6 +66,7 @@ def test_diff(ctx_factory): from loopy.transform.diff import diff_kernel dknl, diff_map = diff_kernel(knl, "z", "x") + dknl = lp.make_program_from_kernel(dknl) dknl = lp.remove_unused_arguments(dknl) dknl = lp.add_inames_to_insn(dknl, "diff_i0", "writes:a_dx or writes:a")