WIP: Buses per warp estimate
I think my buses per warp calculation is correct now. @inducer Can you look at it and see if you agree? It goes something like this:
- Find the inames with local tags and figure out what their coefficients would be if the index were flattened, i.e., find C2, C1, and C0 in var[... C2lid2 + C1lid1 + C0*lid0]. If no local tag is found, set bpw_est=1.
- Figure out the group shape and set non-existent axes to size 1. If group shape is a variable, raise error (for now).
- If any of the coefficients above are variables, just assume they will be larger than the warpsize and set them to warpsize+1 for the purposes of the calculation (so far this seems to perform as intended)
- Calculate. This single line of code where it all comes together took 90% of the time I spent on this... but don't hesitate to tell me what's wrong with my masterpiece:
final_bpw_est = min(
wsize,
ceil(
(
l2coeff*(min(lsize[2], ceil(wsize / (lsize[0]*lsize[1]))) - 1) +
l1coeff*(min(lsize[1], ceil(wsize / lsize[0])) - 1) +
l0coeff*(min(lsize[0], wsize) - 1) +
1
) / wsize
)
)
This is in GlobalMemAccessCounter.map_subscript: https://gitlab.tiker.net/jdsteve2/loopy/blob/buses-per-warp-estimate/loopy/statistics.py#L844
Notes:
- The bpw estimate assumes there are at least 32 threads per group.
- Could get_grid_size_upper_bounds be changed so that it just always gives me a 3D index with 1s for absent axes?
- If my bpw calculation is correct, organizing our mem accesses into properties based on this number will group things significantly differently than our previous stride groupings. I hope that's a good thing :-).
- In figuring out the bpw calculation, I've found some situations where our stride calculation can probably be improved, but for now the stride calculation has not changed.
Edited by James Stevens