diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index 324f7da1a21de0115ea060ff7ef55e52ab0913d4..e8c846fbc491b7049d7820e3ef14d9ed8071ded3 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -44,33 +44,49 @@ from loopy.diagnostic import CannotBranchDomainTree, LoopyError # {{{ unique var names -def _is_var_name_conflicting_with_longer(name_a, name_b): - # Array dimensions implemented as separate arrays generate - # names by appending '_s'. Make sure that no - # conflicts can arise from these names. +class _UniqueVarNameGenerator(UniqueNameGenerator): - # Only deal with the case of b longer than a. - if not name_b.startswith(name_a): - return False + def __init__(self, existing_names=set(), forced_prefix=""): + super(_UniqueVarNameGenerator, self).__init__(existing_names, forced_prefix) + array_prefix_pattern = re.compile("(.*)_s[0-9]+$") - return re.match("^%s_s[0-9]+" % re.escape(name_b), name_a) is not None + array_prefixes = set() + for name in existing_names: + match = array_prefix_pattern.match(name) + if match is None: + continue + array_prefixes.add(match.group(1)) -def _is_var_name_conflicting(name_a, name_b): - if name_a == name_b: - return True + self.conflicting_array_prefixes = array_prefixes + self.array_prefix_pattern = array_prefix_pattern - return ( - _is_var_name_conflicting_with_longer(name_a, name_b) - or _is_var_name_conflicting_with_longer(name_b, name_a)) + def _name_added(self, name): + match = self.array_prefix_pattern.match(name) + if match is None: + return + self.conflicting_array_prefixes.add(match.group(1)) -class _UniqueVarNameGenerator(UniqueNameGenerator): def is_name_conflicting(self, name): - from pytools import any - return any( - _is_var_name_conflicting(name, other_name) - for other_name in self.existing_names) + if name in self.existing_names: + return True + + # Array dimensions implemented as separate arrays generate + # names by appending '_s'. Make sure that no + # conflicts can arise from these names. + + # Case 1: a_s0 is already a name; we are trying to insert a + # Case 2: a is already a name; we are trying to insert a_s0 + + if name in self.conflicting_array_prefixes: + return True + + match = self.array_prefix_pattern.match(name) + if match is None: + return False + + return match.group(1) in self.existing_names # }}} diff --git a/setup.py b/setup.py index a941eecd2b58daf413830fc22500179d3e8a8cf1..150cb1cc4bc6ee13a7d516ab09c8824d76a2c6c9 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ setup(name="loo.py", ], install_requires=[ - "pytools>=2016.2.6", + "pytools>=2017.1", "pymbolic>=2016.2", "genpy>=2016.1.2", "cgen>=2016.1", diff --git a/test/test_loopy.py b/test/test_loopy.py index 4bb6a27267bd7b1880265bdd5b47ee676a480fb3..77fd49e07b8d884fd12324e735770b8d5a488b48 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2231,6 +2231,20 @@ def test_struct_assignment(ctx_factory): knl(queue, N=200) +def test_kernel_var_name_generator(): + knl = lp.make_kernel( + "{[i]: 0 <= i <= 10}", + """ + <>a = 0 + <>b_s0 = 0 + """) + + vng = knl.get_var_name_generator() + + assert vng("a_s0") != "a_s0" + assert vng("b") != "b" + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])