From 4870c19050a0ba29a8432b569d397baa7317b944 Mon Sep 17 00:00:00 2001 From: James Stevens <jdsteve2@illinois.edu> Date: Fri, 16 Oct 2015 20:02:59 -0500 Subject: [PATCH] reg usage estimator in somewhat usable state --- loopy/statistics.py | 149 +++++++++++++++++++--------------------- test/test_statistics.py | 5 +- 2 files changed, 72 insertions(+), 82 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index 48fda3e04..d81050c65 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -550,23 +550,27 @@ class RegisterUsageEstimator(CombineMapper): return 0 #''' def map_variable(self, expr): + name = expr.name if expr in self.vars_found: return 0 - else: - self.vars_found.append(expr) - print("new var found: ", expr) - print("knl.temp_vars: \n", self.knl.temporary_variables) - print("found in temp_vars? ", expr.name in self.knl.temporary_variables) - print("found in inames? ", expr.name in self.knl.all_inames) - #print("knl.vars: \n", self.knl.variables) - if expr.name in self.knl.temporary_variables: - print("local? ", self.knl.temporary_variables[expr.name].is_local) - - #print("local? ", self.knl.temporary_variables[expr.name].is_local) - if "_dim_" in str(expr): #TODO how to remove block/thread size/id vars? + + self.vars_found.append(expr) + if name in self.knl.temporary_variables: + if self.knl.temporary_variables[name].is_local: + print("found temp var with local tag, not counting: ", expr) #TODO remove after debug return 0 else: return 1 + elif name in self.knl.all_inames(): + from loopy.kernel.data import AxisTag + if (self.knl.iname_to_tag.get(name) is None or + not isinstance(self.knl.iname_to_tag.get(name), AxisTag)): + #TODO use more specific positive instead of negative + return 1 + else: + return 0 + else: + return 1 map_tagged_variable = map_variable @@ -917,6 +921,57 @@ def get_regs_per_thread(knl): return get_regs_per_thread3_2(knl) +def get_regs_per_thread3_2(knl): + + """Estimate registers per thread usage by a loopy kernel. + + :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated. + + """ + + from loopy.preprocess import preprocess_kernel, infer_unknown_types + from loopy.schedule import EnterLoop, LeaveLoop, Barrier, RunInstruction + from operator import mul + knl = infer_unknown_types(knl, expect_completion=True) + knl = preprocess_kernel(knl) + knl = lp.get_one_scheduled_kernel(knl) + max_regs = 0 + block_reg_totals = [0] + reg_counters = [RegisterUsageEstimator(knl)] + # multiple counters to track nested sets of previously used iname+index combinations + + for sched_item in knl.schedule: + if isinstance(sched_item, EnterLoop): + if sched_item.iname: # (if not empty) + block_reg_totals.append(0) + # start a new estimator + reg_counters.append(RegisterUsageEstimator(knl)) + else: + print("Error, how does this happen?") #TODO + 1/0 + + elif isinstance(sched_item, LeaveLoop): + if sched_item.iname: # (if not empty) + if block_reg_totals[-1] > max_regs: + max_regs = block_reg_totals[-1] + # pop to resume previous total + block_reg_totals.pop() + reg_counters.pop() + else: + print("Error, how does this happen?") #TODO + 1/0 + elif isinstance(sched_item, RunInstruction): + insn = knl.id_to_insn[sched_item.insn_id] + block_reg_totals[-1] += reg_counters[-1](insn.assignee) + \ + reg_counters[-1](insn.expression) + + # finished looping, check outer block + if block_reg_totals[-1] > max_regs: + max_regs = block_reg_totals[-1] + + return max_regs + +''' # map_var and map_tagged_var returned 1, no checking for any duplication def get_regs_per_thread1(knl): @@ -1084,74 +1139,9 @@ def get_regs_per_thread3(knl): #print("final, totals: \n", block_reg_totals, max_regs) return max_regs +''' -def get_regs_per_thread3_2(knl): - - """Estimate registers per thread usage by a loopy kernel. - - :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated. - - """ - - from loopy.preprocess import preprocess_kernel, infer_unknown_types - from loopy.schedule import EnterLoop, LeaveLoop, Barrier, RunInstruction - from operator import mul - knl = infer_unknown_types(knl, expect_completion=True) - knl = preprocess_kernel(knl) - knl = lp.get_one_scheduled_kernel(knl) - #print(knl) - max_regs = 0 - #current_loop_indices = 0 - block_reg_totals = [0] - reg_counters = [RegisterUsageEstimator(knl)] - # multiple counters to track nested sets of previously used iname+index combinations - - for sched_item in knl.schedule: - if isinstance(sched_item, EnterLoop): - if sched_item.iname: # (if not empty) - #print("entering loop, totals: \n", block_reg_totals, max_regs) - #current_loop_indices += 1 # TODO assumes all loops add 1 new index - # start a new block total - #block_reg_totals.append(current_loop_indices) - block_reg_totals.append(0) - # start a new estimator - reg_counters.append(RegisterUsageEstimator(knl)) - #print("entered loop, totals: \n", block_reg_totals, max_regs) - else: - print("Error, how does this happen?") - 1/0 - - elif isinstance(sched_item, LeaveLoop): - if sched_item.iname: # (if not empty) - #print("leaving loop, totals: \n", block_reg_totals, max_regs) - #current_loop_indices -= 1 # TODO assumes all loops add 1 new index - if block_reg_totals[-1] > max_regs: - max_regs = block_reg_totals[-1] - # pop to resume previous total - #block_reg_totals[-2] += block_reg_totals[-1] - block_reg_totals.pop() - reg_counters.pop() - #print("left loop, totals: \n", block_reg_totals, max_regs) - else: - print("Error, how does this happen?") - 1/0 - elif isinstance(sched_item, RunInstruction): - insn = knl.id_to_insn[sched_item.insn_id] - #print("instruction found: ", insn) - #print("pre insn totals: \n", block_reg_totals, max_regs) - block_reg_totals[-1] += reg_counters[-1](insn.assignee) + \ - reg_counters[-1](insn.expression) - #print("post insn totals: \n", block_reg_totals, max_regs) - # TODO don't count variables if they are loop indices? (also try this with ctr2) - - #print("finished schedule, totals: \n", block_reg_totals, max_regs) - # finished looping, check outer block - if block_reg_totals[-1] > max_regs: - max_regs = block_reg_totals[-1] - #print("final, totals: \n", block_reg_totals, max_regs) - - return max_regs - +''' #add all sub blocks to containing block #aka add everything together def get_regs_per_thread4(knl): @@ -1194,3 +1184,4 @@ def get_regs_per_thread4(knl): reg_counter(insn.expression) return regs+max_loop_indices +''' diff --git a/test/test_statistics.py b/test/test_statistics.py index fedf119dd..3ae1139e6 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -671,7 +671,7 @@ def test_barrier_counter_barriers(): barrier_count = poly.eval_with_dict(params) assert barrier_count == 50*10*2 -''' + def test_reg_counter_basic(): knl = lp.make_kernel( @@ -687,7 +687,6 @@ def test_reg_counter_basic(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) regs = get_regs_per_thread(knl) - 1/0 assert regs == 6 @@ -757,7 +756,7 @@ def test_reg_counter_bitwise(): g=np.int64, h=np.int64)) regs = get_regs_per_thread(knl) assert regs == 6 -''' + def test_all_counters_parallel_matmul(): -- GitLab