Skip to content
Snippets Groups Projects
Commit dd2b6570 authored by Andreas Klöckner
Browse files

Reimplement add_prefetch().

parent 2513ee98
No related branches found
No related tags found
No related merge requests found
......@@ -69,8 +69,6 @@ TODO
a <- cse(reduce(stuff))
- reimplement add_prefetch
- user interface for dim length prescription
- How to determine which variables need to be duplicated for ILP?
......@@ -90,6 +88,8 @@ TODO
Dealt with
^^^^^^^^^^
- reimplement add_prefetch
- Flag, exploit idempotence
- Some things involving CSEs might be impossible to schedule
......
......@@ -266,9 +266,9 @@ def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=Non
# the iname is *not* a dependency of the fetch expression
if iname in duplicate_inames:
raise RuntimeError("duplicating an iname "
"that the CSE does not depend on "
"does not make sense")
raise RuntimeError("duplicating an iname ('%s')"
"that the CSE ('%s') does not depend on "
"does not make sense" % (iname, expr.child))
# Which iname dependencies are carried over from CSE host
# to the CSE compute instruction?
......@@ -495,72 +495,39 @@ def check_kernels(kernel_gen, parameters, kill_level_min=3,
# }}}
# {{{ high-level modifiers
# {{{ convenience
def get_input_access_descriptors(kernel):
    """Collect the input access descriptors of *kernel*.

    An input access descriptor is a tuple ``(input_vec, index_expr)``.

    .. note::

        This function is currently disabled: its first statement
        deliberately raises :exc:`ZeroDivisionError`, so nothing below it
        is ever executed.

    :arg kernel: object providing ``input_vectors()`` and ``instructions``.
    :returns: a dictionary mapping each entry of ``kernel.input_vectors()``
        to the set of input access descriptors found for it.
    """
    1/0  # broken

    from pytools import flatten
    from loopy.symbolic import VariableIndexExpressionCollector

    descriptors = {}
    for ivec in kernel.input_vectors():
        # One collector run per instruction; 'flatten' merges the
        # per-instruction results into a single stream of index expressions.
        index_exprs = flatten(
                VariableIndexExpressionCollector(ivec)(expression)
                for lvalue, expression in kernel.instructions)

        descriptors[ivec] = set((ivec, iexpr) for iexpr in index_exprs)

    return descriptors
def add_prefetch(kernel, input_access_descr, fetch_dims, loc_fetch_axes={}):
    """Return a copy of *kernel* with a :class:`LocalMemoryPrefetch`
    registered for one input access.

    .. note::

        This function is currently disabled: its first statement
        deliberately raises :exc:`ZeroDivisionError`, so nothing below it
        is ever executed.

    :arg input_access_descr: see :func:`get_input_access_descriptors`.
        May also be the name of the variable if there is only one
        reference to that variable.
    :arg fetch_dims: loop dimensions indexing the input variable on which
        the prefetch is to be carried out. Each entry is either a single
        string or an iterable of strings; every string is resolved through
        ``kernel.tag_or_iname_to_iname``.
    :arg loc_fetch_axes: forwarded unchanged to :class:`LocalMemoryPrefetch`.
    :raises ValueError: if a variable name was given and its access
        descriptor is not unique, or if a prefetch for the access is
        already registered.
    """
    1/0  # broken

    if isinstance(input_access_descr, str):
        # A bare variable name was given: look up its (single) descriptor.
        var_name = input_access_descr
        candidates = get_input_access_descriptors(kernel)[var_name]

        if len(candidates) != 1:
            raise ValueError("input access descriptor for variable %s is "
                    "not unique" % var_name)

        (input_access_descr,) = candidates

    def parse_fetch_dim(iname):
        # Normalize a single name to a one-element tuple before resolving.
        names = (iname,) if isinstance(iname, str) else iname
        return tuple(kernel.tag_or_iname_to_iname(s) for s in names)

    fetch_dims = [parse_fetch_dim(fd) for fd in fetch_dims]
    ivec, iexpr = input_access_descr

    # Work on a copy so the mapping held by the incoming kernel is untouched.
    prefetch_map = getattr(kernel, "prefetch", {}).copy()

    if input_access_descr in prefetch_map:
        raise ValueError("a prefetch descriptor for the input access %s[%s] "
                "already exists" % (ivec, iexpr))

    from loopy.prefetch import LocalMemoryPrefetch
    prefetch_map[input_access_descr] = LocalMemoryPrefetch(
            kernel=kernel,
            input_vector=ivec,
            index_expr=iexpr,
            fetch_dims=fetch_dims,
            loc_fetch_axes=loc_fetch_axes)

    return kernel.copy(prefetch=prefetch_map)
def add_prefetch(kernel, var_name, fetch_dims=None):
    """Wrap every access to *var_name* in a freshly tagged common
    subexpression and realize each of these via :func:`realize_cse`.

    :arg kernel: the kernel to transform. New kernels are obtained through
        ``kernel.copy`` and :func:`realize_cse`.
    :arg var_name: name of the variable whose accesses are to be
        prefetched. Its dtype is looked up in ``kernel.arg_dict`` and, if
        not found there, in ``kernel.temporary_variables``.
    :arg fetch_dims: passed to :func:`realize_cse` as its fourth positional
        argument (``duplicate_inames``). Defaults to an empty list.
    :returns: the transformed kernel.
    """
    # Use a fresh list per call rather than one shared mutable default.
    if fetch_dims is None:
        fetch_dims = []

    # Tags handed out by get_unique_cse_tag() during this call.
    # NOTE(review): CSE tags already present in the kernel's instructions
    # are not collected here, so a generated tag could coincide with a
    # pre-existing "fetch_<var_name>..." tag -- confirm whether the
    # instructions should be scanned for existing tags first.
    used_cse_tags = set()

    # A list (not a set) so the CSEs are realized in the order their tags
    # were created, making the result independent of set iteration order.
    new_cse_tags = []

    def get_unique_cse_tag():
        from loopy.tools import generate_unique_possibilities
        for cse_tag in generate_unique_possibilities(prefix="fetch_"+var_name):
            if cse_tag not in used_cse_tags:
                used_cse_tags.add(cse_tag)
                new_cse_tags.append(cse_tag)
                return cse_tag

    from loopy.symbolic import VariableFetchCSEMapper
    vf_cse_mapper = VariableFetchCSEMapper(var_name, get_unique_cse_tag)

    # Rewrite every instruction so each access to var_name is its own CSE.
    kernel = kernel.copy(instructions=[
            insn.copy(expression=vf_cse_mapper(insn.expression))
            for insn in kernel.instructions])

    if var_name in kernel.arg_dict:
        dtype = kernel.arg_dict[var_name].dtype
    else:
        dtype = kernel.temporary_variables[var_name].dtype

    for cse_tag in new_cse_tags:
        kernel = realize_cse(kernel, cse_tag, dtype, fetch_dims)

    return kernel
# }}}
......
......@@ -506,6 +506,30 @@ class IndexVariableFinder(CombineMapper):
# }}}
# {{{ variable-fetch CSE mapper
class VariableFetchCSEMapper(IdentityMapper):
    """Identity mapper that wraps accesses to one named variable in a
    :class:`pymbolic.primitives.CommonSubexpression`.

    Both a bare reference to the variable and a subscript whose aggregate
    is that variable are wrapped. Every wrapped access receives its own
    tag, obtained by calling *cse_tag_getter* with no arguments.

    :arg var_name: name of the variable whose accesses are wrapped.
    :arg cse_tag_getter: zero-argument callable returning the tag for the
        next CSE.
    """

    def __init__(self, var_name, cse_tag_getter):
        self.cse_tag_getter = cse_tag_getter
        self.var_name = var_name

    def map_variable(self, expr):
        from pymbolic.primitives import CommonSubexpression

        if expr.name != self.var_name:
            # Some other variable: leave it to the identity mapping.
            return IdentityMapper.map_variable(self, expr)

        return CommonSubexpression(expr, self.cse_tag_getter())

    def map_subscript(self, expr):
        from pymbolic.primitives import CommonSubexpression, Variable, Subscript

        aggregate = expr.aggregate
        is_target_access = (
                isinstance(aggregate, Variable)
                and aggregate.name == self.var_name)

        if not is_target_access:
            return IdentityMapper.map_subscript(self, expr)

        # Map the index first, then request the tag, so that tags for
        # accesses nested inside the index are requested before this one.
        mapped_index = self.rec(expr.index)
        fetch = Subscript(aggregate, mapped_index)
        tag = self.cse_tag_getter()
        return CommonSubexpression(fetch, tag)
# }}}
......
......@@ -258,7 +258,7 @@ def test_rank_one(ctx_factory):
knl = lp.LoopKernel(ctx.devices[0],
"[n] -> {[i,j]: 0<=i,j<n}",
[
"label: c[i, j] = cse(a[i], a)*cse(b[j], b)"
"label: c[i, j] = a[i]*b[j]"
],
[
lp.ArrayArg("a", dtype, shape=(n,), order=order),
......@@ -269,8 +269,8 @@ def test_rank_one(ctx_factory):
name="rank_one", assumptions="n >= 16")
def variant_1(knl):
knl = lp.realize_cse(knl, "a", dtype)
knl = lp.realize_cse(knl, "b", dtype)
knl = lp.add_prefetch(knl, "a")
knl = lp.add_prefetch(knl, "b")
return knl
def variant_2(knl):
......@@ -279,8 +279,8 @@ def test_rank_one(ctx_factory):
knl = lp.split_dimension(knl, "j", 16,
outer_tag="g.1", inner_tag="l.1")
knl = lp.realize_cse(knl, "a", dtype)
knl = lp.realize_cse(knl, "b", dtype)
knl = lp.add_prefetch(knl, "a")
knl = lp.add_prefetch(knl, "b")
return knl
def variant_3(knl):
......@@ -289,8 +289,8 @@ def test_rank_one(ctx_factory):
knl = lp.split_dimension(knl, "j", 16,
outer_tag="g.1", inner_tag="l.1")
knl = lp.realize_cse(knl, "a", dtype, ["i_inner"])
knl = lp.realize_cse(knl, "b", dtype, ["j_inner"])
knl = lp.add_prefetch(knl, "a", ["i_inner"])
knl = lp.add_prefetch(knl, "b", ["j_inner"])
return knl
def variant_4(knl):
......@@ -299,8 +299,8 @@ def test_rank_one(ctx_factory):
knl = lp.split_dimension(knl, "j", 256,
outer_tag="g.1", slabs=(0, -1))
knl = lp.realize_cse(knl, "a", dtype, ["i_inner"])
knl = lp.realize_cse(knl, "b", dtype, ["j_inner"])
knl = lp.add_prefetch(knl, "a", ["i_inner"])
knl = lp.add_prefetch(knl, "b", ["j_inner"])
knl = lp.split_dimension(knl, "i_inner", 16,
inner_tag="l.0")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment