diff --git a/sumpy/expansion/multipole.py b/sumpy/expansion/multipole.py
index 2a6955e7c56589d6747e6ec855c08ff62adaa12f..5cf4b307874537e055c9633716d9ba314cb54623 100644
--- a/sumpy/expansion/multipole.py
+++ b/sumpy/expansion/multipole.py
@@ -28,7 +28,7 @@ from sumpy.expansion import (
     HelmholtzConformingVolumeTaylorExpansion,
     BiharmonicConformingVolumeTaylorExpansion)
 from sumpy.tools import mi_set_axis
-from pytools import cartesian_product, factorial
+from pytools import factorial
 
 import logging
 logger = logging.getLogger(__name__)
@@ -55,7 +55,7 @@ class VolumeTaylorMultipoleExpansionBase(MultipoleExpansionBase):
     """
 
     def coefficients_from_source(self, kernel, avec, bvec, rscale, sac=None):
-        from sumpy.kernel import DirectionalSourceDerivative
+        from sumpy.kernel import KernelWrapper
         if kernel is None:
             kernel = self.kernel
 
@@ -64,42 +64,11 @@ class VolumeTaylorMultipoleExpansionBase(MultipoleExpansionBase):
         if not self.use_rscale:
             rscale = 1
 
-        if isinstance(kernel, DirectionalSourceDerivative):
-            from sumpy.symbolic import make_sym_vector
-
-            dir_vecs = []
-            tmp_kernel = kernel
-            while isinstance(tmp_kernel, DirectionalSourceDerivative):
-                dir_vecs.append(make_sym_vector(tmp_kernel.dir_vec_name, kernel.dim))
-                tmp_kernel = tmp_kernel.inner_kernel
-
-            if kernel.get_base_kernel() is not tmp_kernel:
-                raise NotImplementedError("Unknown kernel wrapper.")
-
-            nderivs = len(dir_vecs)
-
-            coeff_identifiers = self.get_full_coefficient_identifiers()
-            result = [0] * len(coeff_identifiers)
-
-            for i, mi in enumerate(coeff_identifiers):
-                # One source derivative is the dot product of the gradient and
-                # directional vector.
-                # For multiple derivatives, gradient of gradients is taken.
-                # For eg: in 3D, 2 source derivatives gives 9 terms and
-                # cartesian_product below enumerates these 9 terms.
-                for deriv_terms in cartesian_product(*[range(kernel.dim)]*nderivs):
-                    prod = 1
-                    derivative_mi = list(mi)
-                    for nderivative, deriv_dim in enumerate(deriv_terms):
-                        prod *= -derivative_mi[deriv_dim]
-                        prod *= dir_vecs[nderivative][deriv_dim]
-                        derivative_mi[deriv_dim] -= 1
-                    if any(v < 0 for v in derivative_mi):
-                        continue
-                    result[i] += mi_power(avec, derivative_mi) * prod
-
-            for i, mi in enumerate(coeff_identifiers):
-                result[i] /= (mi_factorial(mi) * rscale ** sum(mi))
+        if isinstance(kernel, KernelWrapper):
+            result = [
+                    kernel.postprocess_at_source(mi_power(avec, mi), avec)
+                    / mi_factorial(mi) / rscale ** sum(mi)
+                    for mi in self.get_full_coefficient_identifiers()]
         else:
             avec = [sym.UnevaluatedExpr(a * rscale**-1) for a in avec]
 
diff --git a/sumpy/kernel.py b/sumpy/kernel.py
index 1820ff1c861a6037815bf65ff81687a6e11cf57b..6b40a6d82f67f3f3041ec61718d7b550caf382b0 100644
--- a/sumpy/kernel.py
+++ b/sumpy/kernel.py
@@ -59,8 +59,9 @@ of them in the process.
 
 .. autoclass:: DerivativeBase
 .. autoclass:: AxisTargetDerivative
-.. autoclass:: DirectionalTargetDerivative
+.. autoclass:: AxisSourceDerivative
 .. autoclass:: DirectionalSourceDerivative
+.. autoclass:: DirectionalTargetDerivative
 
 Transforming kernels
 --------------------
@@ -68,7 +69,9 @@ Transforming kernels
 .. autoclass:: KernelMapper
 .. autoclass:: KernelCombineMapper
 .. autoclass:: KernelIdentityMapper
+.. autoclass:: AxisSourceDerivativeRemover
 .. autoclass:: AxisTargetDerivativeRemover
+.. autoclass:: SourceDerivativeRemover
 .. autoclass:: TargetDerivativeRemover
 .. autoclass:: DerivativeCounter
 """
@@ -824,6 +827,36 @@ class DerivativeBase(KernelWrapper):
     pass
 
 
+class AxisSourceDerivative(DerivativeBase):
+    init_arg_names = ("axis", "inner_kernel")
+
+    def __init__(self, axis, inner_kernel):
+        KernelWrapper.__init__(self, inner_kernel)
+        self.axis = axis
+
+    def __getinitargs__(self):
+        return (self.axis, self.inner_kernel)
+
+    def __str__(self):
+        return "d/dy%d %s" % (self.axis, self.inner_kernel)
+
+    def __repr__(self):
+        return "AxisSourceDerivative(%d, %r)" % (self.axis, self.inner_kernel)
+
+    def postprocess_at_source(self, expr, avec):
+        expr = self.inner_kernel.postprocess_at_source(expr, avec)
+        return -expr.diff(avec[self.axis])
+
+    def replace_base_kernel(self, new_base_kernel):
+        return type(self)(self.axis,
+            self.inner_kernel.replace_base_kernel(new_base_kernel))
+
+    def replace_inner_kernel(self, new_inner_kernel):
+        return type(self)(self.axis, new_inner_kernel)
+
+    mapper_method = "map_axis_source_derivative"
+
+
 class AxisTargetDerivative(DerivativeBase):
     init_arg_names = ("axis", "inner_kernel")
 
@@ -1019,6 +1052,7 @@ class KernelCombineMapper(KernelMapper):
 
     map_directional_target_derivative = map_axis_target_derivative
     map_directional_source_derivative = map_axis_target_derivative
+    map_axis_source_derivative = map_axis_target_derivative
 
 
 class KernelIdentityMapper(KernelMapper):
@@ -1033,7 +1067,9 @@ class KernelIdentityMapper(KernelMapper):
     map_stresslet_kernel = map_expression_kernel
 
     def map_axis_target_derivative(self, kernel):
-        return AxisTargetDerivative(kernel.axis, self.rec(kernel.inner_kernel))
+        return type(kernel)(kernel.axis, self.rec(kernel.inner_kernel))
+
+    map_axis_source_derivative = map_axis_target_derivative
 
     def map_directional_target_derivative(self, kernel):
         return type(kernel)(
@@ -1043,6 +1079,11 @@ class KernelIdentityMapper(KernelMapper):
     map_directional_source_derivative = map_directional_target_derivative
 
 
+class AxisSourceDerivativeRemover(KernelIdentityMapper):
+    def map_axis_source_derivative(self, kernel):
+        return self.rec(kernel.inner_kernel)
+
+
 class AxisTargetDerivativeRemover(KernelIdentityMapper):
     def map_axis_target_derivative(self, kernel):
         return self.rec(kernel.inner_kernel)
@@ -1053,7 +1094,7 @@ class TargetDerivativeRemover(AxisTargetDerivativeRemover):
         return self.rec(kernel.inner_kernel)
 
 
-class SourceDerivativeRemover(KernelIdentityMapper):
+class SourceDerivativeRemover(AxisSourceDerivativeRemover):
     def map_directional_source_derivative(self, kernel):
         return self.rec(kernel.inner_kernel)
 
@@ -1077,6 +1118,7 @@ class DerivativeCounter(KernelCombineMapper):
 
     map_directional_target_derivative = map_axis_target_derivative
     map_directional_source_derivative = map_axis_target_derivative
+    map_axis_source_derivative = map_axis_target_derivative
 
 # }}}
 
diff --git a/sumpy/qbx.py b/sumpy/qbx.py
index 0c0438c7b6a3bc083cfedc27c28116b1e8e3ad8e..ce2ee3bc8e30c8eaba278e330b18abad85879d40 100644
--- a/sumpy/qbx.py
+++ b/sumpy/qbx.py
@@ -464,6 +464,7 @@ class LayerPotentialMatrixBlockGenerator(LayerPotentialBase):
 
 def find_jump_term(kernel, arg_provider):
     from sumpy.kernel import (
+            AxisSourceDerivative,
             AxisTargetDerivative,
             DirectionalSourceDerivative,
             DirectionalTargetDerivative,
@@ -479,6 +480,9 @@ def find_jump_term(kernel, arg_provider):
         elif isinstance(kernel, DirectionalTargetDerivative):
             tgt_derivatives.append(kernel.dir_vec_name)
             kernel = kernel.kernel
+        elif isinstance(kernel, AxisSourceDerivative):
+            src_derivatives.append(kernel.axis)
+            kernel = kernel.kernel
         elif isinstance(kernel, DirectionalSourceDerivative):
             src_derivatives.append(kernel.dir_vec_name)
             kernel = kernel.kernel
diff --git a/test/test_fmm.py b/test/test_fmm.py
index 637eb8b83e504c6da08321af17e08ac2d5709189..45e9bcd36a62c883402a27a716529e694b946da3 100644
--- a/test/test_fmm.py
+++ b/test/test_fmm.py
@@ -418,6 +418,74 @@ def test_sumpy_fmm_exclude_self(ctx_factory):
     assert np.isclose(rel_err, 0, atol=1e-7)
 
 
+def test_sumpy_axis_source_derivative(ctx_factory):
+    logging.basicConfig(level=logging.INFO)
+
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    nsources = 500
+    dtype = np.float64
+
+    from boxtree.tools import (
+            make_normal_particle_array as p_normal)
+
+    knl = LaplaceKernel(2)
+    local_expn_class = VolumeTaylorLocalExpansion
+    mpole_expn_class = VolumeTaylorMultipoleExpansion
+    order = 10
+
+    sources = p_normal(queue, nsources, knl.dim, dtype, seed=15)
+
+    from boxtree import TreeBuilder
+    tb = TreeBuilder(ctx)
+
+    tree, _ = tb(queue, sources,
+            max_particles_in_box=30, debug=True)
+
+    from boxtree.traversal import FMMTraversalBuilder
+    tbuild = FMMTraversalBuilder(ctx)
+    trav, _ = tbuild(queue, tree, debug=True)
+
+    from pyopencl.clrandom import PhiloxGenerator
+    rng = PhiloxGenerator(ctx, seed=12)
+    weights = rng.uniform(queue, nsources, dtype=np.float64)
+
+    target_to_source = np.arange(tree.ntargets, dtype=np.int32)
+    self_extra_kwargs = {"target_to_source": target_to_source}
+
+    from functools import partial
+
+    from sumpy.fmm import SumpyExpansionWranglerCodeContainer
+    from sumpy.kernel import AxisTargetDerivative, AxisSourceDerivative
+
+    pots = []
+    for tgt_knl, src_knl in [(AxisTargetDerivative(0, knl), knl),
+            (knl, AxisSourceDerivative(0, knl))]:
+
+        wcc = SumpyExpansionWranglerCodeContainer(
+                ctx,
+                partial(mpole_expn_class, knl),
+                partial(local_expn_class, knl),
+                target_kernels=[tgt_knl],
+                source_kernels=[src_knl],
+                exclude_self=True)
+
+        wrangler = wcc.get_wrangler(queue, tree, dtype,
+                fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order,
+                self_extra_kwargs=self_extra_kwargs)
+
+        from boxtree.fmm import drive_fmm
+
+        pot, = drive_fmm(trav, wrangler, (weights,))
+        pots.append(pot.get())
+
+    rel_err = la.norm(pots[0] + pots[1]) / la.norm(pots[0])
+    logger.info("order %d -> relative l2 error: %g" % (order, rel_err))
+
+    assert np.isclose(rel_err, 0, atol=1e-5)
+
+
 # You can test individual routines by typing
 # $ python test_fmm.py 'test_sumpy_fmm(cl.create_some_context)'