diff --git a/loopy/preprocess.py b/loopy/preprocess.py index a1420d7bc92b6430f4c26378c57da6d825e48efa..2bdf9a115744b4c657bbd9a1ac8345f484015eb6 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -313,13 +313,17 @@ class ExtraInameIndexInserter(IdentityMapper): self.var_to_new_inames = var_to_new_inames def map_subscript(self, expr): - res = IdentityMapper.map_subscript(self, expr) try: new_idx = self.var_to_new_inames[expr.aggregate.name] except KeyError: return IdentityMapper.map_subscript(self, expr) else: - return res.aggregate[res.index + new_idx] + index = expr.index + if not isinstance(index, tuple): + index = (index,) + index = tuple(self.rec(i) for i in index) + + return expr.aggregate[index + new_idx] def map_variable(self, expr): try: diff --git a/test/test_loopy.py b/test/test_loopy.py index e748e7cc9f7187be73e11dcfe2b05e4e4ca0ca59..0391cf96e5e471c26315e5c8d4dd8625b3119101 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -707,25 +707,23 @@ def test_ilp_write_race_detection_global(ctx_factory): -def test_ilp_write_race_detection_local(ctx_factory): +def test_ilp_write_race_avoidance_local(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel(ctx.devices[0], - "{[i,j]: 0<=i,j<16 }", + "{[i,j]: 0<=i<16 and 0<=j<17 }", [ "[i:l.0, j:ilp] <> a[i] = 5+i+j", ], []) - from loopy.check import WriteRaceConditionError - import pytest - with pytest.raises(WriteRaceConditionError): - list(lp.generate_loop_schedules(knl)) + for k in lp.generate_loop_schedules(knl): + assert k.temporary_variables["a"].shape == (16,17) -def test_ilp_write_race_detection_private(ctx_factory): +def test_ilp_write_race_avoidance_private(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel(ctx.devices[0], @@ -735,10 +733,8 @@ def test_ilp_write_race_detection_private(ctx_factory): ], []) - from loopy.check import WriteRaceConditionError - import pytest - with pytest.raises(WriteRaceConditionError): - list(lp.generate_loop_schedules(knl)) + for k in lp.generate_loop_schedules(knl): + assert k.temporary_variables["a"].shape == (16,) # }}}