diff --git a/loopy/statistics.py b/loopy/statistics.py index 2b10179ea802fb53d79f8a24cfd53599a757d79b..992c95a4e9aee58786220197d7b3372c4b046614 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -30,6 +30,8 @@ from islpy import dim_type import islpy as isl from pymbolic.mapper import CombineMapper from functools import reduce +from loopy.kernel.data import Assignment +from loopy.diagnostic import warn # {{{ ToCountMap @@ -392,6 +394,65 @@ class GlobalSubscriptCounter(CombineMapper): # }}} +# {{{ AccessFootprintGatherer + +class AccessFootprintGatherer(CombineMapper): + def __init__(self, kernel, domain): + self.kernel = kernel + self.domain = domain + + @staticmethod + def combine(values): + assert values + + def merge_dicts(a, b): + result = a.copy() + + for var_name, footprint in six.iteritems(b): + if var_name in result: + result[var_name] = result[var_name] | footprint + else: + result[var_name] = footprint + + return result + + from functools import reduce + return reduce(merge_dicts, values) + + def map_constant(self, expr): + return {} + + def map_variable(self, expr): + return {} + + def map_subscript(self, expr): + subscript = expr.index + + if not isinstance(subscript, tuple): + subscript = (subscript,) + + from loopy.symbolic import get_access_range + + try: + access_range = get_access_range(self.domain, subscript, + self.kernel.assumptions) + except isl.Error: + # Likely: index was non-linear, nothing we can do. + return + except TypeError: + # Likely: index was non-linear, nothing we can do. + return + + from pymbolic.primitives import Variable + assert isinstance(expr.aggregate, Variable) + + return self.combine([ + self.rec(expr.index), + {expr.aggregate.name: access_range}]) + +# }}} + + # {{{ count def count(kernel, bset): @@ -649,4 +710,48 @@ def get_barrier_poly(knl): # }}} + +# {{{ gather_access_footprints + +def gather_access_footprints(kernel): + # TODO: Docs + + from loopy.preprocess import preprocess_kernel, infer_unknown_types + kernel = infer_unknown_types(kernel, expect_completion=True) + kernel = preprocess_kernel(kernel) + + write_footprints = [] + read_footprints = [] + + for insn in kernel.instructions: + if not isinstance(insn, Assignment): + warn(kernel, "count_non_assignment", + "Non-assignment instruction encountered in " + "gather_access_footprints, not counted") + continue + + insn_inames = kernel.insn_inames(insn) + inames_domain = kernel.get_inames_domain(insn_inames) + domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) + + afg = AccessFootprintGatherer(kernel, domain) + + write_footprints.append(afg(insn.assignee)) + read_footprints.append(afg(insn.expression)) + + write_footprints = AccessFootprintGatherer.combine(write_footprints) + read_footprints = AccessFootprintGatherer.combine(read_footprints) + + result = {} + + for vname, footprint in six.iteritems(write_footprints): + result[(vname, "write")] = footprint + + for vname, footprint in six.iteritems(read_footprints): + result[(vname, "read")] = footprint + + return result + +# }}} + # vim: foldmethod=marker diff --git a/test/test_statistics.py b/test/test_statistics.py index 0dffe5c3575237cab8f518ba95a33f74a3bbe840..2d9096e381b7c914b14ce74070a4ed64b82e2eff 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -22,6 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import six import sys from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl @@ -595,6 +596,22 @@ def test_all_counters_parallel_matmul(): assert f32coal == n*l +def test_gather_access_footprint(): + knl = lp.make_kernel( + "{[i,k,j]: 0<=i,j,k<n}", + [ + "c[i, j] = sum(k, a[i, k]*b[k, j]) + a[i,j]" + ], + name="matmul", assumptions="n >= 1") + knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) + + from loopy.statistics import gather_access_footprints, count + fp = gather_access_footprints(knl) + + for key, footprint in six.iteritems(fp): + print(key, count(knl, footprint)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])