From 193ac6378a37ad8fb7382bb3113f4484cc757ff2 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Sat, 20 May 2017 18:26:05 -0500 Subject: [PATCH] _UniqueVarNameGenerator: Fix O(n**2) behavior by keeping track of the array prefixes as they are aded to the set. Closes #25 on gitlab Depends on inducer/pytools!2 on gitlab --- loopy/kernel/__init__.py | 51 +++++++++++++++++++++++++--------------- test/test_loopy.py | 14 +++++++++++ 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index 324f7da1a..da15967c6 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -44,33 +44,46 @@ from loopy.diagnostic import CannotBranchDomainTree, LoopyError # {{{ unique var names -def _is_var_name_conflicting_with_longer(name_a, name_b): - # Array dimensions implemented as separate arrays generate - # names by appending '_s'. Make sure that no - # conflicts can arise from these names. +class _UniqueVarNameGenerator(UniqueNameGenerator): - # Only deal with the case of b longer than a. - if not name_b.startswith(name_a): - return False + def __init__(self, existing_names=set(), forced_prefix=""): + super(_UniqueVarNameGenerator, self).__init__(existing_names, forced_prefix) + array_prefix_pattern = re.compile("(.*)_s[0-9]+$") - return re.match("^%s_s[0-9]+" % re.escape(name_b), name_a) is not None + array_prefixes = set() + for name in existing_names: + match = array_prefix_pattern.match(name) + if match is None: + continue + array_prefixes.add(match.group(1)) -def _is_var_name_conflicting(name_a, name_b): - if name_a == name_b: - return True + self.array_prefixes = array_prefixes + self.array_prefix_pattern = array_prefix_pattern - return ( - _is_var_name_conflicting_with_longer(name_a, name_b) - or _is_var_name_conflicting_with_longer(name_b, name_a)) + def _name_added(self, name): + match = self.array_prefix_pattern.match(name) + if match is None: + return + self.array_prefixes.add(match.group(1)) -class _UniqueVarNameGenerator(UniqueNameGenerator): def is_name_conflicting(self, name): - from pytools import any - return any( - _is_var_name_conflicting(name, other_name) - for other_name in self.existing_names) + if name in self.existing_names: + return True + + # Array dimensions implemented as separate arrays generate + # names by appending '_s'. Make sure that no + # conflicts can arise from these names. + + if name in self.array_prefixes: + return True + + match = self.array_prefix_pattern.match(name) + if match is None: + return False + + return match.group(1) in self.existing_names # }}} diff --git a/test/test_loopy.py b/test/test_loopy.py index 4bb6a2726..77fd49e07 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2231,6 +2231,20 @@ def test_struct_assignment(ctx_factory): knl(queue, N=200) +def test_kernel_var_name_generator(): + knl = lp.make_kernel( + "{[i]: 0 <= i <= 10}", + """ + <>a = 0 + <>b_s0 = 0 + """) + + vng = knl.get_var_name_generator() + + assert vng("a_s0") != "a_s0" + assert vng("b") != "b" + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab