From cd6c4ee94168c06559117dcdfad256126b5738a3 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 3 Mar 2022 16:30:26 -0600 Subject: [PATCH] Fix Named tag to actually prevent name clashes --- pytato/target/loopy/codegen.py | 5 +++-- test/test_codegen.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 1bf286e..8055c12 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -746,8 +746,9 @@ def _generate_name_for_temp(expr: Array, state: CodeGenState) -> str: name_tag, = expr.tags_of_type(Named) if state.var_name_gen.is_name_conflicting(name_tag.name): raise ValueError(f"Cannot assign the name {name_tag.name} to the" - f" temporary corresponding to {expr} as it is" - " referring a loopy kernel argument.") + f" temporary corresponding to {expr} as it " + "conflicts with an existing name. ") + state.var_name_gen.add_name(name_tag.name) return name_tag.name elif expr.tags_of_type(PrefixNamed): prefix_tag, = expr.tags_of_type(PrefixNamed) diff --git a/test/test_codegen.py b/test/test_codegen.py index 2e619f7..6a2080a 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -59,6 +59,18 @@ def test_basic_codegen(ctx_factory): assert (out == x_in * x_in).all() +def test_named_clash(ctx_factory): + x = pt.make_placeholder("x", (5,), np.int64) + + from pytato.tags import ImplStored, Named + expr = ( + (2*x).tagged((Named("xx"), ImplStored())) + + (3*x).tagged((Named("xx"), ImplStored()))) + + with pytest.raises(ValueError): + pt.generate_loopy(expr) + + def test_scalar_placeholder(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) -- GitLab