From 7d33b5c7d6d920ffee95429b7c9085b0afab8497 Mon Sep 17 00:00:00 2001 From: James Stevens <jdsteve2@illinois.edu> Date: Wed, 30 Sep 2015 21:34:27 -0500 Subject: [PATCH] experimenting with reg usage estimator --- loopy/statistics.py | 118 +++++++++++++++++++++++++++++++--------- test/test_statistics.py | 10 ++-- 2 files changed, 98 insertions(+), 30 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index cd7ff8d79..7b1a5c7da 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -410,19 +410,18 @@ class RegisterUsageEstimator(CombineMapper): def map_constant(self, expr): return 0 - - def map_tagged_variable(self, expr): - print("tagged_var found, what to do here? ", expr) #TODO - return 0 - + #''' def map_variable(self, expr): - #print("var: ", expr) if expr in self.vars_found: - #print("var already in list") #TODO delete all these print statements return 0 else: self.vars_found.append(expr) - return 1 + if "_dim_" in str(expr): #TODO how to remove block/thread size/id vars? + return 0 + else: + return 1 + + map_tagged_variable = map_variable #map_variable = map_tagged_variable map_call = map_constant # TODO what is this? @@ -431,32 +430,33 @@ class RegisterUsageEstimator(CombineMapper): name = expr.aggregate.name # name of array if name in self.knl.arg_dict: + # not a temporary variable array = self.knl.arg_dict[name] + elif self.knl.temporary_variables[name].is_local: + # temp var is in shared mem + return 0 + self.rec(expr.index) + elif (expr.index, expr.aggregate) in self.subs_found: + # temp var is NOT shared, but already counted + return 0 + self.rec(expr.index) else: - # this is a temporary variable - #print("subscript, temp var found: ", expr, expr.index, expr.aggregate) - #print((expr.index, expr.aggregate) in self.subs_found) - return 0 - # TODO if this is not in shared, count it - if (expr.index, expr.aggregate) in self.subs_found: - return 0 - else: - self.subs_found.append((expr.index, expr.aggregate)) - return 1 # TODO +self.rec(expr.index)? + # temp var is NOT shared and NOT already counted + self.subs_found.append((expr.index, expr.aggregate)) + return 1 + self.rec(expr.index) + + # expr is not a temporary variable if not isinstance(array, lp.GlobalArg): + print("debug... When does this happen? ", expr, array) + 1/0 # this array is not in global memory - return 0 # TODO is this right? recurse on index? + return 1 + self.rec(expr.index) # TODO # this is a global mem access - #print("subscript, global var found: ", expr, expr.index, expr.aggregate) - #print((expr.index, expr.aggregate) in self.subs_found) - #self.subs_found.append((expr.index, expr.aggregate)) if (expr.index, expr.aggregate) in self.subs_found: - return 0 + return 0 + self.rec(expr.index) else: self.subs_found.append((expr.index, expr.aggregate)) - return 1 # TODO +self.rec(expr.index)? + return 1 + self.rec(expr.index) def map_sum(self, expr): if expr.children: @@ -749,7 +749,7 @@ def get_barrier_poly(knl): def get_regs_per_thread(knl): - return get_regs_per_thread4(knl) + return get_regs_per_thread3_2(knl) # map_var and map_tagged_var returned 1, no checking for any duplication @@ -920,7 +920,75 @@ def get_regs_per_thread3(knl): return max_regs +def get_regs_per_thread3_2(knl): + + """Estimate registers per thread usage by a loopy kernel. + + :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated. + + """ + + from loopy.preprocess import preprocess_kernel, infer_unknown_types + from loopy.schedule import EnterLoop, LeaveLoop, Barrier, RunInstruction + from operator import mul + knl = infer_unknown_types(knl, expect_completion=True) + knl = preprocess_kernel(knl) + knl = lp.get_one_scheduled_kernel(knl) + #print(knl) + max_regs = 0 + #current_loop_indices = 0 + block_reg_totals = [0] + reg_counters = [RegisterUsageEstimator(knl)] + # multiple counters to track nested sets of previously used iname+index combinations + + for sched_item in knl.schedule: + if isinstance(sched_item, EnterLoop): + if sched_item.iname: # (if not empty) + #print("entering loop, totals: \n", block_reg_totals, max_regs) + #current_loop_indices += 1 # TODO assumes all loops add 1 new index + # start a new block total + #block_reg_totals.append(current_loop_indices) + block_reg_totals.append(0) + # start a new estimator + reg_counters.append(RegisterUsageEstimator(knl)) + #print("entered loop, totals: \n", block_reg_totals, max_regs) + else: + print("Error, how does this happen?") + 1/0 + + elif isinstance(sched_item, LeaveLoop): + if sched_item.iname: # (if not empty) + #print("leaving loop, totals: \n", block_reg_totals, max_regs) + #current_loop_indices -= 1 # TODO assumes all loops add 1 new index + if block_reg_totals[-1] > max_regs: + max_regs = block_reg_totals[-1] + # pop to resume previous total + #block_reg_totals[-2] += block_reg_totals[-1] + block_reg_totals.pop() + reg_counters.pop() + #print("left loop, totals: \n", block_reg_totals, max_regs) + else: + print("Error, how does this happen?") + 1/0 + elif isinstance(sched_item, RunInstruction): + insn = knl.id_to_insn[sched_item.insn_id] + #print("instruction found: ", insn) + #print("pre insn totals: \n", block_reg_totals, max_regs) + block_reg_totals[-1] += reg_counters[-1](insn.assignee) + \ + reg_counters[-1](insn.expression) + #print("post insn totals: \n", block_reg_totals, max_regs) + # TODO don't count variables if they are loop indices? (also try this with ctr2) + + #print("finished schedule, totals: \n", block_reg_totals, max_regs) + # finished looping, check outer block + if block_reg_totals[-1] > max_regs: + max_regs = block_reg_totals[-1] + #print("final, totals: \n", block_reg_totals, max_regs) + + return max_regs + #add all sub blocks to containing block +#aka add everything together def get_regs_per_thread4(knl): """Estimate registers per thread usage by a loopy kernel. diff --git a/test/test_statistics.py b/test/test_statistics.py index 985aeb1b9..7d89a521d 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -526,7 +526,7 @@ def test_reg_counter_basic(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) regs = get_regs_per_thread(knl) - assert regs == 9 + assert regs == 6 def test_reg_counter_reduction(): @@ -540,7 +540,7 @@ def test_reg_counter_reduction(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) regs = get_regs_per_thread(knl) - assert regs == 7 + assert regs == 6 def test_reg_counter_logic(): @@ -556,7 +556,7 @@ def test_reg_counter_logic(): knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64)) regs = get_regs_per_thread(knl) - assert regs == 7 + assert regs == 6 def test_reg_counter_specialops(): @@ -574,7 +574,7 @@ def test_reg_counter_specialops(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) regs = get_regs_per_thread(knl) - assert regs == 9 + assert regs == 6 def test_reg_counter_bitwise(): @@ -594,7 +594,7 @@ def test_reg_counter_bitwise(): a=np.int32, b=np.int32, g=np.int64, h=np.int64)) regs = get_regs_per_thread(knl) - assert regs == 11 + assert regs == 6 def test_all_counters_parallel_matmul(): -- GitLab