Skip to content
Snippets Groups Projects
Commit 1c732a8f authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Second half of previous transform jostling commit

parent bbfc3fa7
No related branches found
No related tags found
No related merge requests found
...@@ -26,10 +26,7 @@ THE SOFTWARE. ...@@ -26,10 +26,7 @@ THE SOFTWARE.
import six import six
from six.moves import range, zip from six.moves import range, zip
import islpy as isl from loopy.symbolic import (
from loopy.symbolic import (RuleAwareIdentityMapper, RuleAwareSubstitutionMapper,
SubstitutionRuleMappingContext,
TaggedVariable, Reduction, LinearSubscript, ) TaggedVariable, Reduction, LinearSubscript, )
from loopy.diagnostic import LoopyError, LoopyWarning from loopy.diagnostic import LoopyError, LoopyWarning
...@@ -89,6 +86,9 @@ from loopy.transform.padding import ( ...@@ -89,6 +86,9 @@ from loopy.transform.padding import (
split_array_dim, split_arg_axis, find_padding_multiple, split_array_dim, split_arg_axis, find_padding_multiple,
add_padding) add_padding)
from loopy.transform.ilp import realize_ilp
from loopy.transform.batch import to_batched
# }}} # }}}
from loopy.preprocess import (preprocess_kernel, realize_reduction, from loopy.preprocess import (preprocess_kernel, realize_reduction,
...@@ -149,6 +149,12 @@ __all__ = [ ...@@ -149,6 +149,12 @@ __all__ = [
"split_array_dim", "split_arg_axis", "find_padding_multiple", "split_array_dim", "split_arg_axis", "find_padding_multiple",
"add_padding", "add_padding",
"realize_ilp",
"to_batched",
"fix_parameters",
# }}} # }}}
"get_dot_dependency_graph", "get_dot_dependency_graph",
...@@ -176,7 +182,6 @@ __all__ = [ ...@@ -176,7 +182,6 @@ __all__ = [
# {{{ from this file # {{{ from this file
"fix_parameters",
"register_preamble_generators", "register_preamble_generators",
"register_symbol_manglers", "register_symbol_manglers",
"register_function_manglers", "register_function_manglers",
...@@ -184,8 +189,6 @@ __all__ = [ ...@@ -184,8 +189,6 @@ __all__ = [
"set_caching_enabled", "set_caching_enabled",
"CacheMode", "CacheMode",
"make_copy_kernel", "make_copy_kernel",
"to_batched",
"realize_ilp",
# }}} # }}}
] ]
...@@ -194,84 +197,6 @@ __all__ = [ ...@@ -194,84 +197,6 @@ __all__ = [
# }}} # }}}
# {{{ fix_parameter
def _fix_parameter(kernel, name, value):
def process_set(s):
var_dict = s.get_var_dict()
try:
dt, idx = var_dict[name]
except KeyError:
return s
value_aff = isl.Aff.zero_on_domain(s.space) + value
from loopy.isl_helpers import iname_rel_aff
name_equal_value_aff = iname_rel_aff(s.space, name, "==", value_aff)
s = (s
.add_constraint(
isl.Constraint.equality_from_aff(name_equal_value_aff))
.project_out(dt, idx, 1))
return s
new_domains = [process_set(dom) for dom in kernel.domains]
from pymbolic.mapper.substitutor import make_subst_func
subst_func = make_subst_func({name: value})
from loopy.symbolic import SubstitutionMapper, PartialEvaluationMapper
subst_map = SubstitutionMapper(subst_func)
ev_map = PartialEvaluationMapper()
def map_expr(expr):
return ev_map(subst_map(expr))
from loopy.kernel.array import ArrayBase
new_args = []
for arg in kernel.args:
if arg.name == name:
# remove from argument list
continue
if not isinstance(arg, ArrayBase):
new_args.append(arg)
else:
new_args.append(arg.map_exprs(map_expr))
new_temp_vars = {}
for tv in six.itervalues(kernel.temporary_variables):
new_temp_vars[tv.name] = tv.map_exprs(map_expr)
from loopy.context_matching import parse_stack_match
within = parse_stack_match(None)
rule_mapping_context = SubstitutionRuleMappingContext(
kernel.substitutions, kernel.get_var_name_generator())
esubst_map = RuleAwareSubstitutionMapper(
rule_mapping_context, subst_func, within=within)
return (
rule_mapping_context.finish_kernel(
esubst_map.map_kernel(kernel))
.copy(
domains=new_domains,
args=new_args,
temporary_variables=new_temp_vars,
assumptions=process_set(kernel.assumptions),
))
def fix_parameters(kernel, **value_dict):
for name, value in six.iteritems(value_dict):
kernel = _fix_parameter(kernel, name, value)
return kernel
# }}}
# {{{ set_options # {{{ set_options
def set_options(kernel, *args, **kwargs): def set_options(kernel, *args, **kwargs):
...@@ -427,131 +352,4 @@ def make_copy_kernel(new_dim_tags, old_dim_tags=None): ...@@ -427,131 +352,4 @@ def make_copy_kernel(new_dim_tags, old_dim_tags=None):
# }}} # }}}
# {{{ to_batched
class _BatchVariableChanger(RuleAwareIdentityMapper):
def __init__(self, rule_mapping_context, kernel, batch_varying_args,
batch_iname_expr):
super(_BatchVariableChanger, self).__init__(rule_mapping_context)
self.kernel = kernel
self.batch_varying_args = batch_varying_args
self.batch_iname_expr = batch_iname_expr
def needs_batch_subscript(self, name):
return (
name in self.kernel.temporary_variables
or
name in self.batch_varying_args)
def map_subscript(self, expr, expn_state):
if not self.needs_batch_subscript(expr.aggregate.name):
return super(_BatchVariableChanger, self).map_subscript(expr, expn_state)
idx = expr.index
if not isinstance(idx, tuple):
idx = (idx,)
return type(expr)(expr.aggregate, (self.batch_iname_expr,) + idx)
def map_variable(self, expr, expn_state):
if not self.needs_batch_subscript(expr.name):
return super(_BatchVariableChanger, self).map_variable(expr, expn_state)
return expr.aggregate[self.batch_iname_expr]
def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch"):
"""Takes in a kernel that carries out an operation and returns a kernel
that carries out a batch of these operations.
:arg nbatches: the number of batches. May be a constant non-negative
integer or a string, which will be added as an integer argument.
:arg batch_varying_args: a list of argument names that depend vary per-batch.
Each such variable will have a batch index added.
"""
from pymbolic import var
vng = knl.get_var_name_generator()
batch_iname = vng(batch_iname_prefix)
batch_iname_expr = var(batch_iname)
new_args = []
batch_dom_str = "{[%(iname)s]: 0 <= %(iname)s < %(nbatches)s}" % {
"iname": batch_iname,
"nbatches": nbatches,
}
if not isinstance(nbatches, int):
batch_dom_str = "[%s] -> " % nbatches + batch_dom_str
new_args.append(ValueArg(nbatches, dtype=knl.index_dtype))
nbatches_expr = var(nbatches)
else:
nbatches_expr = nbatches
batch_domain = isl.BasicSet(batch_dom_str)
new_domains = [batch_domain] + knl.domains
for arg in knl.args:
if arg.name in batch_varying_args:
if isinstance(arg, ValueArg):
arg = GlobalArg(arg.name, arg.dtype, shape=(nbatches_expr,),
dim_tags="c")
else:
arg = arg.copy(
shape=(nbatches_expr,) + arg.shape,
dim_tags=("c",) * (len(arg.shape) + 1))
new_args.append(arg)
new_temps = {}
for temp in six.itervalues(knl.temporary_variables):
new_temps[temp.name] = temp.copy(
shape=(nbatches_expr,) + temp.shape,
dim_tags=("c",) * (len(arg.shape) + 1))
knl = knl.copy(
domains=new_domains,
args=new_args,
temporary_variables=new_temps)
rule_mapping_context = SubstitutionRuleMappingContext(
knl.substitutions, vng)
bvc = _BatchVariableChanger(rule_mapping_context,
knl, batch_varying_args, batch_iname_expr)
return rule_mapping_context.finish_kernel(
bvc.map_kernel(knl))
# }}}
# {{{ realize_ilp
def realize_ilp(kernel, iname):
"""Instruction-level parallelism (as realized by the loopy iname
tag ``"ilp"``) provides the illusion that multiple concurrent
program instances execute in lockstep within a single instruction
stream.
To do so, storage that is private to each instruction stream needs to be
duplicated so that each program instance receives its own copy. Storage
that is written to in an instruction using an ILP iname but whose left-hand
side indices do not contain said ILP iname is marked for duplication.
This storage duplication is carried out automatically at code generation
time, but, using this function, can also be carried out ahead of time
on a per-iname basis (so that, for instance, data layout of the duplicated
storage can be controlled explicitly.
"""
from loopy.ilp import add_axes_to_temporaries_for_ilp_and_vec
return add_axes_to_temporaries_for_ilp_and_vec(kernel, iname)
# }}}
# vim: foldmethod=marker # vim: foldmethod=marker
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment