Commit d20e0267 authored by Andreas Klöckner

Add footprint generators for prefetch.

parent fd78a38d
@@ -61,13 +61,15 @@ To-do
- add_prefetch gets a flag to separate out each access
- Making parameters run-time varying, substituting values that
depend on other inames?
- Allow parameters to vary at run time, substituting values
that depend on other inames?
- Fix all tests
- Scalar insn priority
- : in prefetches
Future ideas
^^^^^^^^^^^^
@@ -75,9 +77,6 @@ Future ideas
- String instructions?
- How is intra-instruction ordering of ILP loops going to be determined?
(taking into account that it could vary even per-instruction?)
- Barriers for data exchanged via global vars?
- Float4 joining on fetch/store?
@@ -88,8 +87,6 @@ Future ideas
- Better for loop bound generation
-> Try a triangular loop
- Sharing of checks across ILP instances
- Eliminate the first (pre-)barrier in a loop.
- Generate automatic test against sequential code.
@@ -115,6 +112,11 @@ Future ideas
Dealt with
^^^^^^^^^^
- How is intra-instruction ordering of ILP loops going to be determined?
(taking into account that it could vary even per-instruction?)
- Sharing of checks across ILP instances
- Differentiate ilp.unr from ilp.seq
- Allow complex-valued arithmetic, despite CL's best efforts.
......
@@ -205,6 +205,8 @@ def split_dimension(kernel, split_iname, inner_length,
if split_iname not in kernel.all_inames():
raise ValueError("cannot split loop for unknown variable '%s'" % split_iname)
applied_substitutions = kernel.applied_substitutions[:]
if outer_iname is None:
outer_iname = split_iname+"_outer"
if inner_iname is None:
@@ -248,6 +250,7 @@ def split_dimension(kernel, split_iname, inner_length,
new_insns = []
for insn in kernel.instructions:
subst_map = {var(split_iname): new_loop_index}
applied_substitutions.append(subst_map)
from loopy.symbolic import SubstitutionMapper
subst_mapper = SubstitutionMapper(subst_map.get)
@@ -277,6 +280,7 @@ def split_dimension(kernel, split_iname, inner_length,
.copy(domain=new_domain,
iname_slab_increments=iname_slab_increments,
instructions=new_insns,
applied_substitutions=applied_substitutions,
))
return tag_dimensions(result, {outer_iname: outer_tag, inner_iname: inner_tag})
@@ -367,7 +371,10 @@ def join_dimensions(kernel, inames, new_iname=None, tag=AutoFitLocalIndexTag()):
result = (kernel
.map_expressions(subst_map, exclude_instructions=True)
.copy(instructions=new_insns, domain=new_domain))
.copy(
instructions=new_insns, domain=new_domain,
applied_substitutions=kernel.applied_substitutions + [subst_map]
))
return tag_dimensions(result, {new_iname: tag})
@@ -417,8 +424,43 @@ def tag_dimensions(kernel, iname_to_tag, force=False):
# {{{ convenience: add_prefetch
def add_prefetch(kernel, var_name, sweep_dims=[], dim_arg_names=None,
default_tag="l.auto", rule_name=None):
def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
default_tag="l.auto", rule_name=None, footprint_indices=None):
"""Prefetch all accesses to the variable *var_name*, with all accesses
being swept through *sweep_inames*.
:arg dim_arg_names: List of names representing each fetch axis.
:arg rule_name: base name of the generated temporary variable.
:arg footprint_indices: A list of tuples indicating the index set used
to generate the footprint.
If only one such set of indices is desired, this may also be specified
directly by putting an index expression into *var_name*. Substitutions
such as those occurring in dimension splits are recorded and also
applied to these indices.
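For example, mirroring the usage in *test_nbody* further down in this
commit (a sketch that assumes a kernel *knl* with an argument ``x``
indexed by the inames ``i``, ``j`` and ``k``)::

    knl = lp.split_dimension(knl, "j", 256, slabs=(0, 1))

    # The subscript in var_name supplies the footprint indices. The iname
    # substitution recorded by the preceding split_dimension call (j in
    # terms of j_outer and j_inner) is re-applied to them automatically.
    knl = lp.add_prefetch(knl, "x[j,k]", ["j_inner", "k"])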
"""
# {{{ fish indexing out of var_name and into sweep_indices
from loopy.symbolic import parse
parsed_var_name = parse(var_name)
from pymbolic.primitives import Variable, Subscript
if isinstance(parsed_var_name, Variable):
# no subscript in var_name: footprint indices, if any, come from the
# footprint_indices argument
sweep_indices = footprint_indices
elif isinstance(parsed_var_name, Subscript):
if footprint_indices is not None:
raise TypeError("if footprint_indices is specified, then var_name "
"may not contain a subscript")
assert isinstance(parsed_var_name.aggregate, Variable)
var_name = parsed_var_name.aggregate.name
sweep_indices = [parsed_var_name.index]
else:
raise ValueError("var_name must either be a variable name or a subscript")
# }}}
if rule_name is None:
rule_name = kernel.make_unique_var_name("%s_fetch" % var_name)
@@ -448,19 +490,57 @@ def add_prefetch(kernel, var_name, sweep_dims=[], dim_arg_names=None,
kernel = extract_subst(kernel, rule_name, uni_template, parameters)
new_fetch_dims = []
for fd in sweep_dims:
for fd in sweep_inames:
if isinstance(fd, int):
new_fetch_dims.append(parameters[fd])
else:
new_fetch_dims.append(fd)
return precompute(kernel, rule_name, arg.dtype, sweep_dims,
footprint_generators = None
if sweep_indices is not None:
if not isinstance(sweep_indices, (list, tuple)):
sweep_indices = [sweep_indices]
def standardize_sweep_indices(si):
if isinstance(si, str):
from loopy.symbolic import parse
si = parse(si)
if not isinstance(si, tuple):
si = (si,)
if len(si) != arg.dimensions:
raise ValueError("sweep index '%s' has the wrong number of dimensions")
for subst_map in kernel.applied_substitutions:
from loopy.symbolic import SubstitutionMapper
from pymbolic.mapper.substitutor import make_subst_func
si = SubstitutionMapper(make_subst_func(subst_map))(si)
return si
sweep_indices = [standardize_sweep_indices(si) for si in sweep_indices]
from pymbolic.primitives import Variable
footprint_generators = [
Variable(var_name)(*si) for si in sweep_indices]
new_kernel = precompute(kernel, rule_name, arg.dtype, sweep_inames,
footprint_generators=footprint_generators,
new_storage_axis_names=dim_arg_names,
default_tag=default_tag)
# }}}
# If the rule survived past precompute() (i.e. some accesses fell outside
# the footprint), get rid of it before moving on.
if rule_name in new_kernel.substitutions:
return apply_subst(new_kernel, rule_name)
else:
return new_kernel
# }}}
......
@@ -13,7 +13,7 @@ from pymbolic import var
class InvocationDescriptor(Record):
__slots__ = ["expr", "args", ]
__slots__ = ["expr", "args", "expands_footprint", "is_in_footprint"]
@@ -38,70 +38,105 @@ def to_parameters_or_project_out(param_inames, set_inames, set):
# {{{ construct storage->sweep map
def construct_storage_to_sweep_map(invocation_descriptors, domain_dup_sweep,
def build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
storage_axis_names, storage_axis_sources, prime_sweep_inames):
# The storage map goes from storage axes to domain_dup_sweep.
# The first len(arg_names) storage dimensions are the rule's arguments.
map_space = domain_dup_sweep.get_space()
stor_dim = len(storage_axis_names)
rn = map_space.dim(dim_type.out)
result = None
map_space = map_space.add_dims(dim_type.in_, stor_dim)
for i, saxis in enumerate(storage_axis_names):
# arg names are initially primed, to be replaced with unprimed
# base-0 versions below
for invdesc in invocation_descriptors:
map_space = domain_dup_sweep.get_space()
stor_dim = len(storage_axis_names)
rn = map_space.dim(dim_type.out)
map_space = map_space.set_dim_name(dim_type.in_, i, saxis+"'")
map_space = map_space.add_dims(dim_type.in_, stor_dim)
for i, saxis in enumerate(storage_axis_names):
# arg names are initially primed, to be replaced with unprimed
# base-0 versions below
# map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep]
map_space = map_space.set_dim_name(dim_type.in_, i, saxis+"'")
set_space = map_space.move_dims(
dim_type.out, rn,
dim_type.in_, 0, stor_dim).range()
# map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep]
# set_space: [domain](dup_sweep_index)[dup_sweep][stor_axes']
set_space = map_space.move_dims(
dim_type.out, rn,
dim_type.in_, 0, stor_dim).range()
stor2sweep = None
# set_space: [domain](dup_sweep_index)[dup_sweep][stor_axes']
from loopy.symbolic import aff_from_expr
stor2sweep = None
for saxis, saxis_source in zip(storage_axis_names, storage_axis_sources):
if isinstance(saxis_source, int):
# an argument
cns = isl.Constraint.equality_from_aff(
aff_from_expr(set_space,
var(saxis+"'")
- prime_sweep_inames(invdesc.args[saxis_source])))
else:
# a 'bare' sweep iname
cns = isl.Constraint.equality_from_aff(
aff_from_expr(set_space,
var(saxis+"'")
- prime_sweep_inames(var(saxis_source))))
from loopy.symbolic import aff_from_expr
cns_map = isl.BasicMap.from_constraint(cns)
if stor2sweep is None:
stor2sweep = cns_map
else:
stor2sweep = stor2sweep.intersect(cns_map)
for saxis, saxis_source in zip(storage_axis_names, storage_axis_sources):
if isinstance(saxis_source, int):
# an argument
cns = isl.Constraint.equality_from_aff(
aff_from_expr(set_space,
var(saxis+"'")
- prime_sweep_inames(invdesc.args[saxis_source])))
else:
# a 'bare' sweep iname
cns = isl.Constraint.equality_from_aff(
aff_from_expr(set_space,
var(saxis+"'")
- prime_sweep_inames(var(saxis_source))))
cns_map = isl.BasicMap.from_constraint(cns)
if stor2sweep is None:
stor2sweep = cns_map
stor2sweep = stor2sweep.move_dims(
dim_type.in_, 0,
dim_type.out, rn, stor_dim)
# stor2sweep is back in map_space
return stor2sweep
def build_global_storage_to_sweep_map(invocation_descriptors, domain_dup_sweep,
storage_axis_names, storage_axis_sources, prime_sweep_inames):
"""
As a side effect, this fills out is_in_footprint in the
invocation descriptors.
"""
# The storage map goes from storage axes to domain_dup_sweep.
# The first len(arg_names) storage dimensions are the rule's arguments.
global_stor2sweep = None
# build footprint
for invdesc in invocation_descriptors:
if invdesc.expands_footprint:
stor2sweep = build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
storage_axis_names, storage_axis_sources, prime_sweep_inames)
if global_stor2sweep is None:
global_stor2sweep = stor2sweep
else:
stor2sweep = stor2sweep.intersect(cns_map)
global_stor2sweep = global_stor2sweep.union(stor2sweep)
invdesc.is_in_footprint = True
if isinstance(global_stor2sweep, isl.BasicMap):
global_stor2sweep = isl.Map.from_basic_map(global_stor2sweep)
global_stor2sweep = global_stor2sweep.intersect_range(domain_dup_sweep)
stor2sweep = stor2sweep.move_dims(
dim_type.in_, 0,
dim_type.out, rn, stor_dim)
# check if non-footprint-building invocation descriptors fall into footprint
for invdesc in invocation_descriptors:
stor2sweep = build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
storage_axis_names, storage_axis_sources, prime_sweep_inames)
if isinstance(stor2sweep, isl.BasicMap):
stor2sweep = isl.Map.from_basic_map(stor2sweep)
# stor2sweep is back in map_space
stor2sweep = stor2sweep.intersect_range(domain_dup_sweep)
if result is None:
result = stor2sweep
if not invdesc.expands_footprint:
invdesc.is_in_footprint = stor2sweep.is_subset(global_stor2sweep)
else:
result = result.union(stor2sweep)
assert stor2sweep.domain().is_subset(global_stor2sweep.domain())
return result
return global_stor2sweep
# }}}
@@ -157,14 +192,10 @@ def get_access_info(kernel, subst_name,
# }}}
stor2sweep = construct_storage_to_sweep_map(
stor2sweep = build_global_storage_to_sweep_map(
invocation_descriptors, domain_dup_sweep,
storage_axis_names, storage_axis_sources, prime_sweep_inames)
if isinstance(stor2sweep, isl.BasicMap):
stor2sweep = isl.Map.from_basic_map(stor2sweep)
stor2sweep = stor2sweep.intersect_range(domain_dup_sweep)
storage_base_indices, storage_shape = compute_bounds(
kernel, subst_name, stor2sweep, sweep_inames,
storage_axis_names)
@@ -186,7 +217,7 @@ def get_access_info(kernel, subst_name,
# }}}
# {{{ subtract off the base indices
# add the new, base-0 as new in dimensions
# add the new, base-0 indices as new in dimensions
sp = stor2sweep.get_space()
stor_idx = sp.dim(dim_type.out)
@@ -251,6 +282,7 @@ def simplify_via_aff(expr):
def precompute(kernel, subst_name, dtype, sweep_axes=[],
footprint_generators=None,
storage_axes=None, new_storage_axis_names=None, storage_axis_to_tag={},
default_tag="l.auto"):
"""Precompute the expression described in the substitution rule *subst_name*
@@ -258,8 +290,13 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
a list of *sweep_axes* (order irrelevant) and an ordered list of *storage_axes*
(whose order will describe the axis ordering of the temporary array).
This function will then examine all usage sites of the substitution rule and
determine what the storage footprint of that sweep is.
*subst_name* may contain a period (".") to filter out a subset of the
usage sites of the substitution rule. (Namely those usage sites that
use the same dotted name.)
This function will then examine the *footprint_generators* (or all usage
sites of the substitution rule if not specified) and determine what the
storage footprint of that sweep is.
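For instance (a hypothetical sketch; the rule name, inames and dtype below
are made up), the footprint may be seeded from explicit rule invocations
instead of the usage sites::

    knl = precompute(knl, "x_fetch", np.float32, sweep_axes=["j_inner"],
            footprint_generators=["x_fetch(j, 0)", "x_fetch(j, 1)", "x_fetch(j, 2)"])

Each footprint generator must be an invocation of the substitution rule
(or just the rule's name if it takes no arguments); strings are parsed
before use.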
The following cases can arise for each sweep axis:
@@ -276,7 +313,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
the so-named formal argument at *all* usage sites.
:arg sweep_axes: A :class:`list` of inames and/or rule argument names to be swept.
:arg storage_dims: A :class:`list` of inames and/or rule argument names/indices to be used as storage axes.
:arg storage_axes: A :class:`list` of inames and/or rule argument names/indices to be used as storage axes.
If `storage_axes` is not specified, it defaults to the arrangement
`<direct sweep axes><arguments>` with the direct sweep axes being the
@@ -312,12 +349,32 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
"exclusively of inames" % expr)
invocation_descriptors.append(
InvocationDescriptor(expr=expr, args=args))
InvocationDescriptor(expr=expr, args=args,
expands_footprint=footprint_generators is None))
return expr
from loopy.symbolic import SubstitutionCallbackMapper
scm = SubstitutionCallbackMapper([(subst_name, subst_instance)], gather_substs)
if footprint_generators:
for fpg in footprint_generators:
if isinstance(fpg, str):
from loopy.symbolic import parse
fpg = parse(fpg)
from pymbolic.primitives import Variable, Call
if isinstance(fpg, Variable):
args = ()
elif isinstance(fpg, Call):
args = fpg.parameters
else:
raise ValueError("footprint generator must "
"be substitution rule invocation")
invocation_descriptors.append(
InvocationDescriptor(expr=fpg, args=args,
expands_footprint=True))
# We need to work on the fully expanded form of an expression.
# To that end, instantiate a substitutor.
from loopy.symbolic import ParametrizedSubstitutor
@@ -347,6 +404,9 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
sweep_inames = set()
for invdesc in invocation_descriptors:
if not invdesc.expands_footprint:
continue
for swaxis in sweep_axes:
if isinstance(swaxis, int):
sweep_inames.update(
@@ -369,6 +429,9 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
usage_arg_deps = set()
for invdesc in invocation_descriptors:
if not invdesc.expands_footprint:
continue
for arg in invdesc.args:
usage_arg_deps.update(get_dependencies(arg))
@@ -515,9 +578,27 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
# }}}
# {{{ substitute rule into expressions in kernel
# {{{ substitute rule into expressions in kernel (if within footprint)
left_unused_subst_rule_invocations = [False]
def do_substs(expr, name, instance, args, rec):
if instance != subst_instance:
left_unused_subst_rule_invocations[0] = True
return expr
found = False
for invdesc in invocation_descriptors:
if expr == invdesc.expr:
found = True
break
if not invdesc.is_in_footprint:
left_unused_subst_rule_invocations[0] = True
return expr
assert found, expr
if len(args) != len(subst.arguments):
raise ValueError("invocation of '%s' with too few arguments"
% name)
@@ -543,23 +624,23 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
new_outer_expr = new_outer_expr[tuple(stor_subscript)]
return new_outer_expr
# can't nest, don't recurse
# can't possibly be nested, don't recurse
new_insns = [compute_insn]
sub_map = SubstitutionCallbackMapper([(subst_name, subst_instance)], do_substs)
sub_map = SubstitutionCallbackMapper([subst_name], do_substs)
for insn in kernel.instructions:
new_insn = insn.copy(expression=sub_map(insn.expression))
new_insns.append(new_insn)
# also catch uses of our rule in other substitution rules
new_substs = dict(
(s.name, s.copy(expression=sub_map(s.expression)))
for s in kernel.substitutions.itervalues()
for s in kernel.substitutions.itervalues())
# leave rule be if instance was specified
# (even if it might end up unused--FIXME)
if subst_instance is not None
or s.name != subst_name)
# If the subst above caught all uses of the subst rule, get rid of it.
if not left_unused_subst_rule_invocations[0]:
del new_substs[subst_name]
# }}}
......
@@ -494,6 +494,9 @@ class LoopKernel(Record):
objects
:ivar lowest_priority_inames:
:ivar breakable_inames: these inames' loops may be broken up by the scheduler
:ivar applied_substitutions: A list of past substitution dictionaries that
were applied to the kernel. These are stored so that they may be repeated
on expressions the user specifies later.
:ivar cache_manager:
@@ -510,7 +513,8 @@ class LoopKernel(Record):
temporary_variables={},
local_sizes={},
iname_to_tag={}, iname_to_tag_requests=None, substitutions={},
cache_manager=None, lowest_priority_inames=[], breakable_inames=set()):
cache_manager=None, lowest_priority_inames=[], breakable_inames=set(),
applied_substitutions=[]):
"""
:arg domain: a :class:`islpy.BasicSet`, or a string parseable to a basic set by the isl.
Example: "{[i,j]: 0<=i < 10 and 0<= j < 9}"
@@ -573,7 +577,7 @@ class LoopKernel(Record):
# {{{ instruction parser
def parse_if_necessary(insn):
from pymbolic import parse
from loopy.symbolic import parse
if isinstance(insn, Instruction):
insns.append(insn)
@@ -597,8 +601,7 @@ class LoopKernel(Record):
raise RuntimeError("insn parse error")
lhs = parse(groups["lhs"])
from loopy.symbolic import FunctionToPrimitiveMapper
rhs = FunctionToPrimitiveMapper()(parse(groups["rhs"]))
rhs = parse(groups["rhs"])
if insn_match is not None:
if groups["label"] is not None:
@@ -707,7 +710,8 @@ class LoopKernel(Record):
substitutions=substitutions,
cache_manager=cache_manager,
lowest_priority_inames=lowest_priority_inames,
breakable_inames=breakable_inames)
breakable_inames=breakable_inames,
applied_substitutions=applied_substitutions)
def make_unique_instruction_id(self, insns=None, based_on="insn", extra_used_ids=set()):
if insns is None:
......
@@ -117,7 +117,7 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase):
# }}}
# {{{ functions to primitives
# {{{ functions to primitives, parsing
class FunctionToPrimitiveMapper(IdentityMapper):
"""Looks for invocations of a function called 'cse' or 'reduce' and
@@ -194,6 +194,10 @@ class FunctionToPrimitiveMapper(IdentityMapper):
return Reduction(operation, tuple(processed_inames), red_expr)
def parse(expr_str):
from pymbolic import parse
return FunctionToPrimitiveMapper()(parse(expr_str))
# }}}
# {{{ reduction loop splitter
......
from __future__ import division
import numpy as np
import pyopencl as cl
import loopy as lp
from pyopencl.tools import pytest_generate_tests_for_pyopencl \
as pytest_generate_tests
def test_nbody(ctx_factory):
dtype = np.float32
ctx = ctx_factory()
knl = lp.make_kernel(ctx.devices[0],
"[N] -> {[i,j,k]: 0<=i,j<N and 0<=k<3 }",
[
"axdist(k) := x[i,k]-x[j,k]",
"invdist := rsqrt(sum_float32(k, axdist(k)**2))",
"pot[i] = sum_float32(j, if(i != j, invdist, 0))",
],
[
lp.ArrayArg("x", dtype, shape="N,3", order="C"),
lp.ArrayArg("pot", dtype, shape="N", order="C"),
lp.ScalarArg("N", np.int32),
],
name="nbody", assumptions="N>=1")
seq_knl = knl
def variant_1(knl):
knl = lp.split_dimension(knl, "i", 256,
outer_tag="g.0", inner_tag="l.0",
slabs=(0,1))
knl = lp.split_dimension(knl, "j", 256, slabs=(0,1))
return knl, []
def variant_cpu(knl):
knl = lp.split_dimension(knl, "i", 1024,
outer_tag="g.0", slabs=(0,1))
return knl, []
def variant_gpu(knl):
knl = lp.split_dimension(knl, "i", 256,
outer_tag="g.0", inner_tag="l.0", slabs=(0,1))
knl = lp.split_dimension(knl, "j", 256, slabs=(0,1))
knl = lp.add_prefetch(knl, "x[i,k]", ["k"], default_tag=None)
knl = lp.add_prefetch(knl, "x[j,k]", ["j_inner", "k"])
return knl, ["j_outer", "j_inner"]
n = 100
for variant in [variant_gpu]:
variant_knl, loop_prio = variant(knl)
kernel_gen = lp.generate_loop_schedules(variant_knl,
loop_priority=loop_prio)
kernel_gen = lp.check_kernels(kernel_gen, dict(N=n))
lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen,
op_count=4*n**2*1e-9, op_label="GOps/s",
parameters={"N": n}, print_ref_code=True)
if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from py.test.cmdline import main
main([__file__])