diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py
index 92fb232e92a6e891efcef7022e0dceedcb4f692f..34dc9e5c723b3979386c090bbedb98ec2bc9bc88 100644
--- a/loopy/kernel/array.py
+++ b/loopy/kernel/array.py
@@ -732,6 +732,23 @@ class ArrayBase(Record):
+    def __eq__(self, other):
+        """Return whether *other* describes the same array as *self*.
+
+        Both objects must be of exactly the same type and agree on name,
+        picklable dtype, dimension tags and order.  Shape and offset are
+        compared through the helpers in :mod:`loopy.symbolic` instead of
+        a plain ``==``.
+        """
+        from loopy.symbolic import (
+                is_tuple_of_expressions_equal as shapes_equal,
+                is_expression_equal as exprs_equal)
+        # Objects of different types never compare equal; bail out before
+        # touching attributes that *other* may not have.
+        same_type = type(self) == type(other)
+        if not same_type:
+            return same_type
+        return (
+                self.name == other.name
+                and self.picklable_dtype == other.picklable_dtype
+                and shapes_equal(self.shape, other.shape)
+                and self.dim_tags == other.dim_tags
+                and exprs_equal(self.offset, other.offset)
+                and self.order == other.order)
+    def __ne__(self, other):
+        """Return the logical negation of ``self.__eq__(other)``."""
+        is_equal = self.__eq__(other)
+        return not is_equal
     def dtype(self):
         from loopy.tools import PicklableDtype
diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py
index c5cecfde2fa4005669d1fca5f3439ca282f2c3c0..647eb5e36c2d05748c4882cb8b3e3d3d84d981ae 100644
--- a/loopy/kernel/data.py
+++ b/loopy/kernel/data.py
@@ -376,6 +376,14 @@ class TemporaryVariable(ArrayBase):
     def __str__(self):
         return self.stringify(include_typename=False)
+    def __eq__(self, other):
+        """Return whether *other* is an equal temporary variable.
+
+        On top of the base-class comparison, the storage shape, base
+        indices, locality flag and base storage must match.
+        """
+        base_equal = super(TemporaryVariable, self).__eq__(other)
+        # If the base-class check already failed, hand back its result
+        # unchanged and skip the temporary-specific attributes.
+        if not base_equal:
+            return base_equal
+        return (
+                self.storage_shape == other.storage_shape
+                and self.base_indices == other.base_indices
+                and self.is_local == other.is_local
+                and self.base_storage == other.base_storage)
     def update_persistent_hash(self, key_hash, key_builder):
         """Custom hash computation function for use with
diff --git a/loopy/statistics.py b/loopy/statistics.py
index 834f482072a51386460e09d6c4f4d6a4406fa56a..7a064520b0d7da152f4d152568b83b68437961bd 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -605,4 +605,3 @@ def get_barrier_poly(knl):
                 barrier_poly += isl.PwQPolynomial('{ 1 }')
     return barrier_poly
diff --git a/loopy/subst.py b/loopy/subst.py
index a29e950a1f32d660eb10147c8638612078e816aa..5211be2d9a7e2163f87dbf27b131497ef485e0db 100644
--- a/loopy/subst.py
+++ b/loopy/subst.py
@@ -276,6 +276,8 @@ class AssignmentToSubstChanger(RuleAwareIdentityMapper):
         from pymbolic import var
         if index is None:
             return var(subst_name)
+        elif not isinstance(index, tuple):
+            return var(subst_name)(index)
             return var(subst_name)(*index)
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index d65440d57005183b2756040473080337257c5b52..d7a49fb22508c27d0a827189204480ab61cc7132 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -1186,4 +1186,47 @@ class AccessRangeMapper(WalkMapper):
 # }}}
+# {{{ is_expression_equal
+def is_expression_equal(a, b):
+    """Return whether *a* and *b* are equal, possibly up to expansion.
+
+    Operands that already compare equal with ``==`` are accepted right
+    away.  Otherwise, if at least one operand is a pymbolic
+    ``Expression`` and neither is *None*, the difference ``a - b`` is
+    passed through :func:`pymbolic.distribute` and the operands count as
+    equal when the result compares equal to zero.  Everything else is
+    unequal.
+    """
+    if a == b:
+        return True
+    from pymbolic.primitives import Expression
+    # Without a symbolic operand there is nothing left to simplify.
+    if not isinstance(a, Expression) and not isinstance(b, Expression):
+        return False
+    # *None* cannot take part in the subtraction below.
+    if a is None or b is None:
+        return False
+    difference = a - b
+    from pymbolic import distribute
+    return distribute(difference) == 0
+def is_tuple_of_expressions_equal(a, b):
+    """Return whether *a* and *b* are element-wise equal expression tuples.
+
+    Two *None* values are equal; *None* never equals anything else.  A
+    non-tuple operand is treated as a one-element tuple.  Tuples of
+    different length are unequal; otherwise every pair of entries must
+    satisfy :func:`is_expression_equal`.
+    """
+    if a is None or b is None:
+        return a is None and b is None
+    lhs = a if isinstance(a, tuple) else (a,)
+    rhs = b if isinstance(b, tuple) else (b,)
+    if len(lhs) != len(rhs):
+        return False
+    for lhs_entry, rhs_entry in zip(lhs, rhs):
+        if not is_expression_equal(lhs_entry, rhs_entry):
+            return False
+    return True
+# }}}
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index c8072108032e82517773709f5f5fd257928d3bd9..0142fd6d4962da8c555412ee43a2a349c3e6d521 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -29,6 +29,7 @@ import sys
 import numpy as np
 import loopy as lp
 import pyopencl as cl
+import pyopencl.clmath  # noqa
 import pyopencl.clrandom  # noqa
 import pytest
@@ -2228,6 +2229,67 @@ def test_indexof_vec(ctx_factory):
     #assert np.array_equal(out.ravel(order="C"), np.arange(25))
+def test_finite_difference_expr_subst(ctx_factory):
+    # Builds a finite-difference kernel and a flux kernel, fuses them,
+    # turns the intermediate array "f" into a substitution rule and then
+    # precomputes it.  There are no explicit assertions: the test passes
+    # if every kernel invocation below runs without raising.
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+    # 2048 sample points on [0, 2*pi) with the right endpoint excluded;
+    # h is the spacing between neighboring points.
+    grid = np.linspace(0, 2*np.pi, 2048, endpoint=False)
+    h = grid[1] - grid[0]
+    # u = sin(grid), evaluated on the device.
+    u = cl.clmath.sin(cl.array.to_device(queue, grid))
+    # Reads f at i-1 and i+1 and writes out[i]; "out" is declared with
+    # shape "n+2", the remaining arguments are left to loopy ("...").
+    fin_diff_knl = lp.make_kernel(
+        "{[i]: 1<=i<=n}",
+        "out[i] = -(f[i+1] - f[i-1])/h",
+        [lp.GlobalArg("out", shape="n+2"), "..."])
+    # Computes f from u; both arrays are declared with shape "n+2".
+    flux_knl = lp.make_kernel(
+        "{[j]: 1<=j<=n}",
+        "f[j] = u[j]**2/2",
+        [
+            lp.GlobalArg("f", shape="n+2"),
+            lp.GlobalArg("u", shape="n+2"),
+            ])
+    # Fusing requires the "f" argument of both kernels to be recognized
+    # as the same array.
+    fused_knl = lp.fuse_kernels([fin_diff_knl, flux_knl])
+    fused_knl = lp.set_options(fused_knl, write_cl=True)
+    # First run: fused kernel with "f" still an array.
+    evt, _ = fused_knl(queue, u=u, h=np.float32(1e-1))
+    # Replace the assignment to "f" by a substitution rule.
+    fused_knl = lp.assignment_to_subst(fused_knl, "f")
+    fused_knl = lp.set_options(fused_knl, write_cl=True)
+    # This is the real test here: The automatically generated
+    # shape expressions are '2+n' and the ones above are 'n+2'.
+    # Is loopy smart enough to understand that these are equal?
+    evt, _ = fused_knl(queue, u=u, h=np.float32(1e-1))
+    # Shift the loop variable: the new iname "inew" satisfies inew+1 = i.
+    fused0_knl = lp.affine_map_inames(fused_knl, "i", "inew", "inew+1=i")
+    # Split "inew" in chunks of 128, tagging the outer part "g.0" and the
+    # inner part "l.0".
+    gpu_knl = lp.split_iname(
+            fused0_knl, "inew", 128, outer_tag="g.0", inner_tag="l.0")
+    # Precompute the "f_subst" rule across "inew_inner".
+    precomp_knl = lp.precompute(
+            gpu_knl, "f_subst", "inew_inner", fetch_bounding_box=True)
+    precomp_knl = lp.tag_inames(precomp_knl, {"j_0_outer": "unr"})
+    precomp_knl = lp.set_options(precomp_knl, return_dict=True)
+    # Final run; note that this one passes the actual grid spacing h
+    # rather than the float32 constant used in the earlier runs.
+    evt, _ = precomp_knl(queue, u=u, h=h)
+def test_is_expression_equal():
+    """Check is_expression_equal on reordered and expanded expressions."""
+    from loopy.symbolic import is_expression_equal
+    from pymbolic import var
+    x, y = var("x"), var("y")
+    # Same sum with the operands swapped.
+    swapped_lhs, swapped_rhs = x+2, 2+x
+    assert is_expression_equal(swapped_lhs, swapped_rhs)
+    # Squared binomial in one variable against its expanded form.
+    square_lhs, square_rhs = (x+2)**2, x**2 + 4*x + 4
+    assert is_expression_equal(square_lhs, square_rhs)
+    # Squared binomial in two variables against its expanded form.
+    mixed_lhs, mixed_rhs = (x+y)**2, x**2 + 2*x*y + y**2
+    assert is_expression_equal(mixed_lhs, mixed_rhs)
 if __name__ == "__main__":
     if len(sys.argv) > 1: