Skip to content
Snippets Groups Projects
Commit 4870c190 authored by James Stevens's avatar James Stevens
Browse files

reg usage estimator in somewhat usable state

parent ffec7cab
No related branches found
No related tags found
No related merge requests found
...@@ -550,23 +550,27 @@ class RegisterUsageEstimator(CombineMapper): ...@@ -550,23 +550,27 @@ class RegisterUsageEstimator(CombineMapper):
return 0 return 0
#''' #'''
def map_variable(self, expr):
    """Return the number of registers the variable *expr* contributes.

    A variable is counted at most once per estimator instance: the first
    visit records it in ``self.vars_found``, and every later visit of an
    equal variable returns 0.

    :parameter expr: the variable node being visited. Its ``name`` is
        looked up in ``self.knl.temporary_variables`` and
        ``self.knl.all_inames()``.
    :return: 0 if *expr* was already seen, if it names a temporary
        variable whose ``is_local`` flag is set, or if it names an iname
        tagged with an :class:`loopy.kernel.data.AxisTag`; 1 otherwise.
    """
    name = expr.name

    if expr in self.vars_found:
        return 0
    # Record the variable before classifying it, so uncounted variables
    # are also remembered and short-circuit on later visits.
    self.vars_found.append(expr)

    temp_vars = self.knl.temporary_variables
    if name in temp_vars:
        # Temporaries flagged is_local are not counted.
        return 0 if temp_vars[name].is_local else 1

    if name in self.knl.all_inames():
        from loopy.kernel.data import AxisTag
        # An untagged iname yields None from .get(); None is not an
        # AxisTag instance, so one isinstance check covers both the
        # "no tag" and "non-axis tag" cases.
        #TODO use more specific positive instead of negative
        tag = self.knl.iname_to_tag.get(name)
        return 0 if isinstance(tag, AxisTag) else 1

    # Neither a temporary nor an iname: count it.
    return 1
map_tagged_variable = map_variable map_tagged_variable = map_variable
...@@ -917,6 +921,57 @@ def get_regs_per_thread(knl): ...@@ -917,6 +921,57 @@ def get_regs_per_thread(knl):
return get_regs_per_thread3_2(knl) return get_regs_per_thread3_2(knl)
def get_regs_per_thread3_2(knl):
    """Estimate registers per thread usage by a loopy kernel.

    The kernel is type-inferred, preprocessed and scheduled, then its
    schedule is walked once. Every loop entered starts a fresh running
    total and a fresh :class:`RegisterUsageEstimator`; both are discarded
    when the loop is left. The estimate is the largest total reached by
    any single block (a loop body or the outermost level). Inner totals
    are not added to their enclosing block.

    :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated.
    :return: the largest per-block register count found in the schedule.
    :raises RuntimeError: if an ``EnterLoop`` or ``LeaveLoop`` schedule
        item has an empty ``iname``.
    """
    from loopy.preprocess import preprocess_kernel, infer_unknown_types
    from loopy.schedule import EnterLoop, LeaveLoop, RunInstruction

    knl = infer_unknown_types(knl, expect_completion=True)
    knl = preprocess_kernel(knl)
    knl = lp.get_one_scheduled_kernel(knl)

    max_regs = 0
    block_reg_totals = [0]
    # multiple counters to track nested sets of previously used
    # iname+index combinations
    reg_counters = [RegisterUsageEstimator(knl)]

    for sched_item in knl.schedule:
        if isinstance(sched_item, EnterLoop):
            if not sched_item.iname:
                raise RuntimeError(
                    "EnterLoop schedule item has an empty iname")
            # start a new block total and a new estimator for this loop
            block_reg_totals.append(0)
            reg_counters.append(RegisterUsageEstimator(knl))
        elif isinstance(sched_item, LeaveLoop):
            if not sched_item.iname:
                raise RuntimeError(
                    "LeaveLoop schedule item has an empty iname")
            # fold the finished block into the maximum, then pop to
            # resume the enclosing block's total and estimator
            max_regs = max(max_regs, block_reg_totals.pop())
            reg_counters.pop()
        elif isinstance(sched_item, RunInstruction):
            insn = knl.id_to_insn[sched_item.insn_id]
            counter = reg_counters[-1]
            block_reg_totals[-1] += (
                counter(insn.assignee) + counter(insn.expression))

    # finished looping, check outer block
    return max(max_regs, block_reg_totals[-1])
'''
# map_var and map_tagged_var returned 1, no checking for any duplication # map_var and map_tagged_var returned 1, no checking for any duplication
def get_regs_per_thread1(knl): def get_regs_per_thread1(knl):
...@@ -1084,74 +1139,9 @@ def get_regs_per_thread3(knl): ...@@ -1084,74 +1139,9 @@ def get_regs_per_thread3(knl):
#print("final, totals: \n", block_reg_totals, max_regs) #print("final, totals: \n", block_reg_totals, max_regs)
return max_regs return max_regs
'''
def get_regs_per_thread3_2(knl): '''
"""Estimate registers per thread usage by a loopy kernel.
:parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated.
"""
from loopy.preprocess import preprocess_kernel, infer_unknown_types
from loopy.schedule import EnterLoop, LeaveLoop, Barrier, RunInstruction
from operator import mul
knl = infer_unknown_types(knl, expect_completion=True)
knl = preprocess_kernel(knl)
knl = lp.get_one_scheduled_kernel(knl)
#print(knl)
max_regs = 0
#current_loop_indices = 0
block_reg_totals = [0]
reg_counters = [RegisterUsageEstimator(knl)]
# multiple counters to track nested sets of previously used iname+index combinations
for sched_item in knl.schedule:
if isinstance(sched_item, EnterLoop):
if sched_item.iname: # (if not empty)
#print("entering loop, totals: \n", block_reg_totals, max_regs)
#current_loop_indices += 1 # TODO assumes all loops add 1 new index
# start a new block total
#block_reg_totals.append(current_loop_indices)
block_reg_totals.append(0)
# start a new estimator
reg_counters.append(RegisterUsageEstimator(knl))
#print("entered loop, totals: \n", block_reg_totals, max_regs)
else:
print("Error, how does this happen?")
1/0
elif isinstance(sched_item, LeaveLoop):
if sched_item.iname: # (if not empty)
#print("leaving loop, totals: \n", block_reg_totals, max_regs)
#current_loop_indices -= 1 # TODO assumes all loops add 1 new index
if block_reg_totals[-1] > max_regs:
max_regs = block_reg_totals[-1]
# pop to resume previous total
#block_reg_totals[-2] += block_reg_totals[-1]
block_reg_totals.pop()
reg_counters.pop()
#print("left loop, totals: \n", block_reg_totals, max_regs)
else:
print("Error, how does this happen?")
1/0
elif isinstance(sched_item, RunInstruction):
insn = knl.id_to_insn[sched_item.insn_id]
#print("instruction found: ", insn)
#print("pre insn totals: \n", block_reg_totals, max_regs)
block_reg_totals[-1] += reg_counters[-1](insn.assignee) + \
reg_counters[-1](insn.expression)
#print("post insn totals: \n", block_reg_totals, max_regs)
# TODO don't count variables if they are loop indices? (also try this with ctr2)
#print("finished schedule, totals: \n", block_reg_totals, max_regs)
# finished looping, check outer block
if block_reg_totals[-1] > max_regs:
max_regs = block_reg_totals[-1]
#print("final, totals: \n", block_reg_totals, max_regs)
return max_regs
#add all sub blocks to containing block #add all sub blocks to containing block
#aka add everything together #aka add everything together
def get_regs_per_thread4(knl): def get_regs_per_thread4(knl):
...@@ -1194,3 +1184,4 @@ def get_regs_per_thread4(knl): ...@@ -1194,3 +1184,4 @@ def get_regs_per_thread4(knl):
reg_counter(insn.expression) reg_counter(insn.expression)
return regs+max_loop_indices return regs+max_loop_indices
'''
...@@ -671,7 +671,7 @@ def test_barrier_counter_barriers(): ...@@ -671,7 +671,7 @@ def test_barrier_counter_barriers():
barrier_count = poly.eval_with_dict(params) barrier_count = poly.eval_with_dict(params)
assert barrier_count == 50*10*2 assert barrier_count == 50*10*2
'''
def test_reg_counter_basic(): def test_reg_counter_basic():
knl = lp.make_kernel( knl = lp.make_kernel(
...@@ -687,7 +687,6 @@ def test_reg_counter_basic(): ...@@ -687,7 +687,6 @@ def test_reg_counter_basic():
knl = lp.add_and_infer_dtypes(knl, knl = lp.add_and_infer_dtypes(knl,
dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
regs = get_regs_per_thread(knl) regs = get_regs_per_thread(knl)
1/0
assert regs == 6 assert regs == 6
...@@ -757,7 +756,7 @@ def test_reg_counter_bitwise(): ...@@ -757,7 +756,7 @@ def test_reg_counter_bitwise():
g=np.int64, h=np.int64)) g=np.int64, h=np.int64))
regs = get_regs_per_thread(knl) regs = get_regs_per_thread(knl)
assert regs == 6 assert regs == 6
'''
def test_all_counters_parallel_matmul(): def test_all_counters_parallel_matmul():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment