Skip to content
Snippets Groups Projects
Commit 5986977e authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Improve, document gather_access_footprints, add gather_access_footprint_bytes

parent 598a2415
No related branches found
No related tags found
No related merge requests found
......@@ -108,7 +108,8 @@ from loopy.schedule import generate_loop_schedules, get_one_scheduled_kernel
from loopy.statistics import (get_op_poly, sum_ops_to_dtypes,
get_gmem_access_poly,
get_DRAM_access_poly, get_barrier_poly, stringify_stats_mapping,
sum_mem_access_to_bytes)
sum_mem_access_to_bytes,
gather_access_footprints, gather_access_footprint_bytes)
from loopy.codegen import generate_code, generate_body
from loopy.compiled import CompiledKernel
from loopy.options import Options
......@@ -200,6 +201,7 @@ __all__ = [
"get_op_poly", "sum_ops_to_dtypes", "get_gmem_access_poly",
"get_DRAM_access_poly",
"get_barrier_poly", "stringify_stats_mapping", "sum_mem_access_to_bytes",
"gather_access_footprints", "gather_access_footprint_bytes",
"CompiledKernel",
......
......@@ -32,7 +32,7 @@ from pytools import memoize_in
from pymbolic.mapper import CombineMapper
from functools import reduce
from loopy.kernel.data import MultiAssignmentBase
from loopy.diagnostic import warn
from loopy.diagnostic import warn, LoopyError
__doc__ = """
......@@ -47,6 +47,9 @@ __doc__ = """
.. autofunction:: get_barrier_poly
.. autofunction:: gather_access_footprints
.. autofunction:: gather_access_footprint_bytes
"""
......@@ -415,9 +418,10 @@ class GlobalSubscriptCounter(CombineMapper):
# {{{ AccessFootprintGatherer
class AccessFootprintGatherer(CombineMapper):
def __init__(self, kernel, domain):
def __init__(self, kernel, domain, ignore_uncountable=False):
self.kernel = kernel
self.domain = domain
self.ignore_uncountable = ignore_uncountable
@staticmethod
def combine(values):
......@@ -456,10 +460,17 @@ class AccessFootprintGatherer(CombineMapper):
self.kernel.assumptions)
except isl.Error:
# Likely: index was non-linear, nothing we can do.
return
if self.ignore_uncountable:
return {}
else:
raise LoopyError("failed to gather footprint: %s" % expr)
except TypeError:
# Likely: index was non-linear, nothing we can do.
return
if self.ignore_uncountable:
return {}
else:
raise LoopyError("failed to gather footprint: %s" % expr)
from pymbolic.primitives import Variable
assert isinstance(expr.aggregate, Variable)
......@@ -838,8 +849,16 @@ def get_barrier_poly(knl):
# {{{ gather_access_footprints
def gather_access_footprints(kernel):
# TODO: Docs
def gather_access_footprints(kernel, ignore_uncountable=False):
"""Return a dictionary mapping ``(var_name, direction)``
to :class:`islpy.Set` instances capturing which indices
of each the array *var_name* are read/written (where
*direction* is either ``read`` or ``write``.
:arg ignore_uncountable: If *True*, an error will be raised for
accesses on which the footprint cannot be determined (e.g.
data-dependent or nonlinear indices)
"""
from loopy.preprocess import preprocess_kernel, infer_unknown_types
kernel = infer_unknown_types(kernel, expect_completion=True)
......@@ -859,7 +878,8 @@ def gather_access_footprints(kernel):
inames_domain = kernel.get_inames_domain(insn_inames)
domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
afg = AccessFootprintGatherer(kernel, domain)
afg = AccessFootprintGatherer(kernel, domain,
ignore_uncountable=ignore_uncountable)
for assignee in insn.assignees:
write_footprints.append(afg(insn.assignees))
......@@ -878,6 +898,35 @@ def gather_access_footprints(kernel):
return result
def gather_access_footprint_bytes(kernel, ignore_uncountable=False):
"""Return a dictionary mapping ``(var_name, direction)`` to
:class:`islpy.PwQPolynomial` instances capturing the number of bytes are
read/written (where *direction* is either ``read`` or ``write`` on array
*var_name*
:arg ignore_uncountable: If *True*, an error will be raised for
accesses on which the footprint cannot be determined (e.g.
data-dependent or nonlinear indices)
"""
result = {}
fp = gather_access_footprints(kernel, ignore_uncountable=ignore_uncountable)
for key, var_fp in fp.items():
vname, direction = key
var_descr = kernel.get_var_descriptor(vname)
bytes_transferred = (
int(var_descr.dtype.numpy_dtype.itemsize)
* count(kernel, var_fp))
if key in result:
result[key] += bytes_transferred
else:
result[key] = bytes_transferred
return result
# }}}
# vim: foldmethod=marker
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment