Skip to content
Snippets Groups Projects
Commit 8f626245 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Add, test to_batched

parent 11d87fcc
No related branches found
No related tags found
No related merge requests found
...@@ -472,6 +472,11 @@ Arguments ...@@ -472,6 +472,11 @@ Arguments
.. autofunction:: add_and_infer_dtypes .. autofunction:: add_and_infer_dtypes
Batching
^^^^^^^^
.. autofunction:: to_batched
Finishing up Finishing up
^^^^^^^^^^^^ ^^^^^^^^^^^^
......
...@@ -1984,4 +1984,107 @@ def alias_temporaries(knl, names, base_name_prefix=None): ...@@ -1984,4 +1984,107 @@ def alias_temporaries(knl, names, base_name_prefix=None):
# }}} # }}}
# {{{ to_batched
class _BatchVariableChanger(RuleAwareIdentityMapper):
def __init__(self, rule_mapping_context, kernel, batch_varying_args,
batch_iname_expr):
super(_BatchVariableChanger, self).__init__(rule_mapping_context)
self.kernel = kernel
self.batch_varying_args = batch_varying_args
self.batch_iname_expr = batch_iname_expr
def needs_batch_subscript(self, name):
return (
name in self.kernel.temporary_variables
or
name in self.batch_varying_args)
def map_subscript(self, expr, expn_state):
if not self.needs_batch_subscript(expr.aggregate.name):
return super(_BatchVariableChanger, self).map_subscript(expr, expn_state)
idx = expr.index
if not isinstance(idx, tuple):
idx = (idx,)
return type(expr)(expr.aggregate, (self.batch_iname_expr,) + idx)
def map_variable(self, expr, expn_state):
if not self.needs_batch_subscript(expr.name):
return super(_BatchVariableChanger, self).map_variable(expr, expn_state)
return expr.aggregate[self.batch_iname_expr]
def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch"):
"""Takes in a kernel that carries out an operation and returns a kernel
that carries out a batch of these operations.
:arg nbatches: the number of batches. May be a constant non-negative
integer or a string, which will be added as an integer argument.
:arg batch_varying_args: a list of argument names that depend vary per-batch.
Each such variable will have a batch index added.
"""
from pymbolic import var
vng = knl.get_var_name_generator()
batch_iname = vng(batch_iname_prefix)
batch_iname_expr = var(batch_iname)
new_args = []
batch_dom_str = "{[%(iname)s]: 0 <= %(iname)s < %(nbatches)s}" % {
"iname": batch_iname,
"nbatches": nbatches,
}
if not isinstance(nbatches, int):
batch_dom_str = "[%s] -> " % nbatches + batch_dom_str
new_args.append(ValueArg(nbatches, dtype=knl.index_dtype))
nbatches_expr = var(nbatches)
else:
nbatches_expr = nbatches
batch_domain = isl.BasicSet(batch_dom_str)
new_domains = [batch_domain] + knl.domains
for arg in knl.args:
if arg.name in batch_varying_args:
if isinstance(arg, ValueArg):
arg = GlobalArg(arg.name, arg.dtype, shape=(nbatches_expr,),
dim_tags="c")
else:
arg = arg.copy(
shape=(nbatches_expr,) + arg.shape,
dim_tags=("c",) * (len(arg.shape) + 1))
new_args.append(arg)
new_temps = {}
for temp in six.itervalues(knl.temporary_variables):
new_temps[temp.name] = temp.copy(
shape=(nbatches_expr,) + temp.shape,
dim_tags=("c",) * (len(arg.shape) + 1))
knl = knl.copy(
domains=new_domains,
args=new_args,
temporary_variables=new_temps)
rule_mapping_context = SubstitutionRuleMappingContext(
knl.substitutions, vng)
bvc = _BatchVariableChanger(rule_mapping_context,
knl, batch_varying_args, batch_iname_expr)
return rule_mapping_context.finish_kernel(
bvc.map_kernel(knl))
# }}}
# vim: foldmethod=marker # vim: foldmethod=marker
...@@ -2158,6 +2158,22 @@ def test_sci_notation_literal(ctx_factory): ...@@ -2158,6 +2158,22 @@ def test_sci_notation_literal(ctx_factory):
assert (np.abs(out.get() - 1e-12) < 1e-20).all() assert (np.abs(out.get() - 1e-12) < 1e-20).all()
def test_to_batched(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
knl = lp.make_kernel(
''' { [i,j]: 0<=i,j<n } ''',
''' out[i] = sum(j, a[i,j]*x[j])''')
bknl = lp.to_batched(knl, "nbatches", "out,x")
a = np.random.randn(5, 5)
x = np.random.randn(7, 5)
bknl(queue, a=a, x=x)
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) > 1: if len(sys.argv) > 1:
exec(sys.argv[1]) exec(sys.argv[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment