Skip to content
Snippets Groups Projects
Commit 92bf9408 authored by Andreas Klöckner
Browse files

Implement access footprint gathering

parent 6e649f7b
No related branches found
No related tags found
No related merge requests found
......@@ -30,6 +30,8 @@ from islpy import dim_type
import islpy as isl
from pymbolic.mapper import CombineMapper
from functools import reduce
from loopy.kernel.data import Assignment
from loopy.diagnostic import warn
# {{{ ToCountMap
......@@ -392,6 +394,65 @@ class GlobalSubscriptCounter(CombineMapper):
# }}}
# {{{ AccessFootprintGatherer
class AccessFootprintGatherer(CombineMapper):
    """Collect, per subscripted variable, the range of indices accessed.

    Every mapper method returns a dict that maps a variable name to the
    access range obtained from :func:`loopy.symbolic.get_access_range`.
    Results of sub-expressions are merged through :meth:`combine`.
    """

    def __init__(self, kernel, domain):
        """
        :arg kernel: the kernel whose ``assumptions`` are handed to
            ``get_access_range``.
        :arg domain: the domain handed to ``get_access_range`` for every
            subscript encountered.
        """
        self.kernel = kernel
        self.domain = domain

    @staticmethod
    def combine(values):
        """Merge a non-empty sequence of footprint dicts into one dict.

        Footprints recorded for the same variable name are joined with
        the ``|`` operator; names occurring only once are carried over
        unchanged. None of the input dicts is modified.
        """
        assert values

        def merge_dicts(a, b):
            result = a.copy()
            for var_name, footprint in six.iteritems(b):
                if var_name in result:
                    result[var_name] = result[var_name] | footprint
                else:
                    result[var_name] = footprint
            return result

        from functools import reduce
        return reduce(merge_dicts, values)

    def map_constant(self, expr):
        """Constants access nothing."""
        return {}

    def map_variable(self, expr):
        """Bare (unsubscripted) variables contribute no footprint."""
        return {}

    def map_subscript(self, expr):
        """Return the footprint of *expr* merged with that of its index.

        If the access range cannot be determined, an empty dict is
        returned so that the result can still be merged by
        :meth:`combine`.
        """
        subscript = expr.index
        if not isinstance(subscript, tuple):
            subscript = (subscript,)

        from loopy.symbolic import get_access_range

        try:
            access_range = get_access_range(self.domain, subscript,
                    self.kernel.assumptions)
        except (isl.Error, TypeError):
            # Likely: index was non-linear, nothing we can do.
            # Return an empty footprint rather than None: combine() calls
            # dict methods on every value it is given.
            return {}

        from pymbolic.primitives import Variable
        assert isinstance(expr.aggregate, Variable)

        return self.combine([
            self.rec(expr.index),
            {expr.aggregate.name: access_range}])
# }}}
# {{{ count
def count(kernel, bset):
......@@ -649,4 +710,48 @@ def get_barrier_poly(knl):
# }}}
# {{{ gather_access_footprints
def gather_access_footprints(kernel):
    """Gather the read and write access footprints of *kernel*.

    The kernel is first run through ``infer_unknown_types`` and
    ``preprocess_kernel``. Each :class:`Assignment` instruction is then
    visited with an :class:`AccessFootprintGatherer` built on the
    instruction's inames domain; its assignee contributes to the write
    footprints and its expression to the read footprints. Instructions
    that are not assignments trigger a warning and are skipped.

    :arg kernel: the kernel to analyze.
    :returns: a dict mapping ``(var_name, "write")`` and
        ``(var_name, "read")`` to the combined footprint of that variable.
        The dict is empty if no assignment instruction was found.
    """
    from loopy.preprocess import preprocess_kernel, infer_unknown_types
    kernel = infer_unknown_types(kernel, expect_completion=True)
    kernel = preprocess_kernel(kernel)

    write_footprints = []
    read_footprints = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment):
            warn(kernel, "count_non_assignment",
                    "Non-assignment instruction encountered in "
                    "gather_access_footprints, not counted")
            continue

        insn_inames = kernel.insn_inames(insn)
        inames_domain = kernel.get_inames_domain(insn_inames)
        domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))

        afg = AccessFootprintGatherer(kernel, domain)

        write_footprints.append(afg(insn.assignee))
        read_footprints.append(afg(insn.expression))

    # combine() asserts that its input is non-empty, so a kernel without any
    # assignment instructions must be handled here.
    write_footprints = (
            AccessFootprintGatherer.combine(write_footprints)
            if write_footprints else {})
    read_footprints = (
            AccessFootprintGatherer.combine(read_footprints)
            if read_footprints else {})

    result = {}

    for vname, footprint in six.iteritems(write_footprints):
        result[(vname, "write")] = footprint

    for vname, footprint in six.iteritems(read_footprints):
        result[(vname, "read")] = footprint

    return result
# }}}
# vim: foldmethod=marker
......@@ -22,6 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import six
import sys
from pyopencl.tools import ( # noqa
pytest_generate_tests_for_pyopencl
......@@ -595,6 +596,22 @@ def test_all_counters_parallel_matmul():
assert f32coal == n*l
def test_gather_access_footprint():
    """Smoke-test gather_access_footprints on a matmul-like kernel.

    Builds the kernel, gathers its footprints and prints the count of
    each one; nothing is asserted about the values.
    """
    from loopy.statistics import gather_access_footprints, count

    instructions = ["c[i, j] = sum(k, a[i, k]*b[k, j]) + a[i,j]"]
    matmul_knl = lp.make_kernel(
            "{[i,k,j]: 0<=i,j,k<n}",
            instructions,
            name="matmul", assumptions="n >= 1")
    matmul_knl = lp.add_and_infer_dtypes(
            matmul_knl, {"a": np.float32, "b": np.float32})

    footprints = gather_access_footprints(matmul_knl)

    # One line of output per (variable, access direction) pair.
    for access_key, access_set in six.iteritems(footprints):
        print(access_key, count(matmul_knl, access_set))
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment