Skip to content
Snippets Groups Projects
Commit 6ff4c160 authored by James Stevens's avatar James Stevens
Browse files

cleaned up barrier counter

parent 94b7d7f1
No related branches found
No related tags found
No related merge requests found
......@@ -447,41 +447,35 @@ def get_DRAM_access_poly(knl): # for now just counting subscripts
def get_barrier_poly(knl):
from loopy.preprocess import preprocess_kernel, infer_unknown_types
from loopy.schedule import EnterLoop, LeaveLoop, Barrier
from operator import mul
knl = infer_unknown_types(knl, expect_completion=True)
knl = preprocess_kernel(knl)
knl = lp.get_one_scheduled_kernel(knl)
loop_iters = [1] # [isl.PwQPolynomial('[]->{ 1 }')]
barrier_poly = 0 # isl.PwQPolynomial('[]->{ 0 }')
from loopy.schedule import EnterLoop, LeaveLoop, Barrier
from operator import mul
print("TESTING... kernel sched: \n", knl.schedule)
iname_list = []
barrier_poly = isl.PwQPolynomial('{ 0 }') # 0
for sched_item in knl.schedule:
print("TESTING... sched_item: ", sched_item)
if isinstance(sched_item, EnterLoop):
print("TESTING... iname: ", sched_item.iname)
ct = count(knl, (
knl.get_inames_domain(sched_item.iname).
project_out_except(sched_item.iname, [dim_type.set])
))
if ct is not None:
loop_iters.append(ct)
if sched_item.iname: # (if not empty)
iname_list.append(sched_item.iname)
elif isinstance(sched_item, LeaveLoop):
print("TESTING... iname: ", sched_item.iname)
ct = count(knl, (
knl.get_inames_domain(sched_item.iname).
project_out_except(sched_item.iname, [dim_type.set])
))
if ct is not None:
loop_iters.pop()
if sched_item.iname: # (if not empty)
iname_list.pop()
elif isinstance(sched_item, Barrier):
print("TESTING... I FOUND A BARRIER!!!")
barrier_poly += reduce(mul, loop_iters)
print("TESTING... current iter list: \n", loop_iters)
print("TESTING... current iter product: \n", reduce(mul, loop_iters))
if iname_list: # (if iname_list is not empty)
ct = (count(knl, (
knl.get_inames_domain(iname_list).
project_out_except(iname_list, [dim_type.set])
)), )
barrier_poly += reduce(mul, ct)
else:
barrier_poly += isl.PwQPolynomial('{ 1 }')
'''
if not isinstance(barrier_poly, isl.PwQPolynomial):
# TODO figure out how to fix this
#TODO figure out better fix for this
string = "{"+str(barrier_poly)+"}"
return isl.PwQPolynomial(string)
'''
return barrier_poly
......@@ -424,7 +424,32 @@ def test_barrier_counter_basic():
l = 128
barrier_count = poly.eval_with_dict({'n': n, 'm': m, 'l': l})
assert barrier_count == 0
# TODO test kernels with barriers
def test_barrier_counter():
knl = lp.make_kernel(
"[n,m,l] -> {[i,k,j]: 0<=i<50 and 1<=k<98 and 0<=j<10}",
[
"""
c[i,j,k] = 2*a[i,j,k] {id=first}
e[i,j,k] = c[i,j,k+1]+c[i,j,k-1] {dep=first}
"""
], [
lp.TemporaryVariable("c", lp.auto, shape=(50, 10, 99)),
"..."
],
name="weird2",
)
knl = lp.add_and_infer_dtypes(knl, dict(a=np.int32))
knl = lp.split_iname(knl, "k", 128, outer_tag="g.0", inner_tag="l.0")
poly = get_barrier_poly(knl)
n = 512
m = 256
l = 128
barrier_count = poly.eval_with_dict({'n': n, 'm': m, 'l': l})
assert barrier_count == 1000
# TODO more barrier counting tests
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment