diff --git a/loopy/precompute.py b/loopy/precompute.py
index 11b1396f15b4f9dc440ee75480a1d25fbc1e091a..ae973f98c1de87e2821575f3c65c03c989f696fc 100644
--- a/loopy/precompute.py
+++ b/loopy/precompute.py
@@ -1,4 +1,4 @@
-from __future__ import division, absolute_import
+from __future__ import division, absolute_import, print_function
 import six
 from six.moves import range, zip
 
@@ -25,6 +25,7 @@ THE SOFTWARE.
 """
 
 
+import islpy as isl
 from loopy.symbolic import (get_dependencies,
         RuleAwareIdentityMapper, RuleAwareSubstitutionMapper,
         SubstitutionRuleMappingContext)
@@ -450,21 +451,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     if precompute_inames is not None:
         preexisting_precompute_inames = (
                 set(precompute_inames) & kernel.all_inames())
-
-        if (
-                preexisting_precompute_inames
-                and
-                len(preexisting_precompute_inames) < len(precompute_inames)):
-            raise LoopyError("some (but not all) of the inames in "
-                    "precompute_inames already exist. existing: %s non-existing: %s"
-                    % (
-                        preexisting_precompute_inames,
-                        set(precompute_inames) - preexisting_precompute_inames))
-
-        precompute_inames_already_exist = bool(preexisting_precompute_inames)
-
     else:
-        precompute_inames_already_exist = False
+        preexisting_precompute_inames = set()
 
     # }}}
 
@@ -483,20 +471,22 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
             old_name = saxis
             name = "%s_%s" % (c_subst_name, old_name)
 
-        if precompute_inames is not None and i < len(precompute_inames):
+        if (precompute_inames is not None
+                and i < len(precompute_inames)
+                and precompute_inames[i]):
             name = precompute_inames[i]
             tag_lookup_saxis = name
-            if (not precompute_inames_already_exist
+            if (name not in preexisting_precompute_inames
                     and var_name_gen.is_name_conflicting(name)):
                 raise RuntimeError("new storage axis name '%s' "
                         "conflicts with existing name" % name)
-
-        if not precompute_inames_already_exist:
+        else:
             name = var_name_gen(name)
 
         storage_axis_names.append(name)
-        new_iname_to_tag[name] = storage_axis_to_tag.get(
-                tag_lookup_saxis, default_tag)
+        if name not in preexisting_precompute_inames:
+            new_iname_to_tag[name] = storage_axis_to_tag.get(
+                    tag_lookup_saxis, default_tag)
 
         prior_storage_axis_name_dict[name] = old_name
 
@@ -522,9 +512,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     if storage_axis_names:
         # {{{ find domain to be changed
 
-        change_inames = expanding_inames
-        if precompute_inames_already_exist:
-            change_inames = change_inames | preexisting_precompute_inames
+        change_inames = expanding_inames | preexisting_precompute_inames
 
         from loopy.kernel.tools import DomainChanger
         domch = DomainChanger(kernel, change_inames)
@@ -551,40 +539,105 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
             else:
                 del new_iname_to_tag[saxis]
 
-        if not precompute_inames_already_exist:
-            new_kernel_domains = domch.get_domains_with(
-                    abm.augment_domain_with_sweep(
-                        domch.domain, non1_storage_axis_names,
-                        boxify_sweep=fetch_bounding_box))
-        else:
-            check_domain = domch.domain
+                if saxis in preexisting_precompute_inames:
+                    raise LoopyError("precompute axis %d (1-based) was "
+                            "eliminated as "
+                            "having length 1 but also mapped to existing "
+                            "iname '%s'" % (i+1, saxis))
+
+        mod_domain = domch.domain
+
+        # {{{ modify the domain, taking into account preexisting inames
+
+        # inames may already exist in mod_domain, add them primed to start
+        primed_non1_saxis_names = [
+                iname+"'" for iname in non1_storage_axis_names]
 
-            # {{{ check the domain the preexisting inames' domain
+        mod_domain = abm.augment_domain_with_sweep(
+            domch.domain, primed_non1_saxis_names,
+            boxify_sweep=fetch_bounding_box)
 
-            # inames already exist in check_domain, add them primed
-            primed_non1_saxis_names = [
-                    iname+"'" for iname in non1_storage_axis_names]
+        check_domain = mod_domain
+
+        for i, saxis in enumerate(non1_storage_axis_names):
+            var_dict = mod_domain.get_var_dict(isl.dim_type.set)
+
+            if saxis in preexisting_precompute_inames:
+                # add equality constraint between existing and new variable
+                # (new_var - saxis == 0; an inequality would bound one side only)
+                dt, dim_idx = var_dict[saxis]
+                saxis_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)
+
+                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
+                new_var_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)
+
+                mod_domain = mod_domain.add_constraint(
+                        isl.Constraint.equality_from_aff(new_var_aff - saxis_aff))
+
+                # project out the new one
+                mod_domain = mod_domain.project_out(dt, dim_idx, 1)
+
+            else:
+                # remove the prime from the new variable
+                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
+                mod_domain = mod_domain.set_dim_name(dt, dim_idx, saxis)
 
-            check_domain = abm.augment_domain_with_sweep(
-                check_domain, primed_non1_saxis_names,
-                boxify_sweep=fetch_bounding_box)
+        # {{{ check that we got the desired domain
 
-            # project out the original copies
-            from loopy.isl_helpers import project_out
-            check_domain = project_out(check_domain, non1_storage_axis_names)
+        check_domain = check_domain.project_out_except(
+                primed_non1_saxis_names, [isl.dim_type.set])
 
-            for iname in non1_storage_axis_names:
-                var_dict = check_domain.get_var_dict()
-                dt, dim_idx = var_dict[iname+"'"]
-                check_domain = check_domain.set_dim_name(dt, dim_idx, iname)
+        mod_check_domain = mod_domain
 
-            if not (check_domain <= domch.domain and domch.domain <= check_domain):
-                raise LoopyError("domain of preexisting inames does not match "
-                        "domain needed for precompute")
+        # re-add the prime from the new variable
+        var_dict = mod_check_domain.get_var_dict(isl.dim_type.set)
 
-            # }}}
+        for saxis in non1_storage_axis_names:
+            dt, dim_idx = var_dict[saxis]
+            mod_check_domain = mod_check_domain.set_dim_name(dt, dim_idx, saxis+"'")
+
+        mod_check_domain = mod_check_domain.project_out_except(
+                primed_non1_saxis_names, [isl.dim_type.set])
+
+        mod_check_domain, check_domain = isl.align_two(
+                mod_check_domain, check_domain)
+
+        # The modified domain can't get bigger by adding constraints
+        assert mod_check_domain <= check_domain
+
+        if not check_domain <= mod_check_domain:
+            print(check_domain)
+            print(mod_check_domain)
+            raise LoopyError("domain of preexisting inames does not match "
+                    "domain needed for precompute")
+
+        # }}}
+
+        # {{{ check that we didn't shrink the original domain
+
+        # project out the new names from the modified domain
+        orig_domain_inames = list(domch.domain.get_var_dict(isl.dim_type.set))
+        mod_check_domain = mod_domain.project_out_except(
+                orig_domain_inames, [isl.dim_type.set])
+
+        check_domain = domch.domain
+
+        mod_check_domain, check_domain = isl.align_two(
+                mod_check_domain, check_domain)
+
+        # The modified domain can't get bigger by adding constraints
+        assert mod_check_domain <= check_domain
+
+        if not check_domain <= mod_check_domain:
+            print(check_domain)
+            print(mod_check_domain)
+            raise LoopyError("original domain got shrunk by applying the precompute")
+
+        # }}}
+
+        # }}}
 
-            new_kernel_domains = domch.get_domains_with(domch.domain)
+        new_kernel_domains = domch.get_domains_with(mod_domain)
 
     else:
         # leave kernel domains unchanged
diff --git a/test/test_fortran.py b/test/test_fortran.py
index 4117b80a27b243dee1db94b5a0bb2b83b2ec8d49..c31c370076b681cb0593f38b6a4d92479541b872 100644
--- a/test/test_fortran.py
+++ b/test/test_fortran.py
@@ -453,6 +453,49 @@ def test_parse_and_fuse_two_kernels():
     knl, = lp.parse_transformed_fortran(fortran_src)
 
 
+def test_precompute_some_exist(ctx_factory):
+    fortran_src = """
+        subroutine dgemm(m,n,l,a,b,c)
+          implicit none
+          real*8 a(m,l),b(l,n),c(m,n)
+          integer m,n,k,i,j,l
+
+          do j = 1,n
+            do i = 1,m
+              do k = 1,l
+                c(i,j) = c(i,j) + b(k,j)*a(i,k)
+              end do
+            end do
+          end do
+        end subroutine
+        """
+
+    knl, = lp.parse_fortran(fortran_src)
+
+    assert len(knl.domains) == 1
+    # reference = kernel before any transformation, so the comparison is real
+    ref_knl = knl
+
+    knl = lp.split_iname(knl, "i", 8,
+            outer_tag="g.0", inner_tag="l.1")
+    knl = lp.split_iname(knl, "j", 8,
+            outer_tag="g.1", inner_tag="l.0")
+    knl = lp.split_iname(knl, "k", 8)
+    knl = lp.assume(knl, "n mod 8 = 0")
+    knl = lp.assume(knl, "m mod 8 = 0")
+    knl = lp.assume(knl, "l mod 8 = 0")
+
+    knl = lp.extract_subst(knl, "a_acc", "a[i1,i2]", parameters="i1, i2")
+    knl = lp.extract_subst(knl, "b_acc", "b[i1,i2]", parameters="i1, i2")
+    knl = lp.precompute(knl, "a_acc", "k_inner,i_inner",
+            precompute_inames="ktemp,itemp")
+    knl = lp.precompute(knl, "b_acc", "j_inner,k_inner",
+            precompute_inames="itemp,k2temp")
+
+    ctx = ctx_factory()
+    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=128, m=128, l=128))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])