diff --git a/doc/reference.rst b/doc/reference.rst
index e04d0fa2e1a895d385a67da195d7544e97f2f69a..1c9dda7b5db4831534809ff59ef76216cd21b5cc 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -472,6 +472,11 @@ Arguments
 
 .. autofunction:: add_and_infer_dtypes
 
+Batching
+^^^^^^^^
+
+.. autofunction:: to_batched
+
 Finishing up
 ^^^^^^^^^^^^
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index e8dd4e54a6e3b7ba9701e300c027256a5524edbf..d24f507c2fb1e87cca3e9cd7b8dd67778bbf6253 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1984,4 +1984,123 @@ def alias_temporaries(knl, names, base_name_prefix=None):
 
 # }}}
 
+
+# {{{ to_batched
+
+class _BatchVariableChanger(RuleAwareIdentityMapper):
+    """Prepend the batch index to every reference to a batched variable.
+
+    A variable is batched if it is a temporary of *kernel* or is named in
+    *batch_varying_args*.
+    """
+
+    def __init__(self, rule_mapping_context, kernel, batch_varying_args,
+            batch_iname_expr):
+        super(_BatchVariableChanger, self).__init__(rule_mapping_context)
+
+        self.kernel = kernel
+        self.batch_varying_args = batch_varying_args
+        self.batch_iname_expr = batch_iname_expr
+
+    def needs_batch_subscript(self, name):
+        return (
+                name in self.kernel.temporary_variables
+                or
+                name in self.batch_varying_args)
+
+    def map_subscript(self, expr, expn_state):
+        if not self.needs_batch_subscript(expr.aggregate.name):
+            return super(_BatchVariableChanger, self).map_subscript(expr, expn_state)
+
+        idx = expr.index
+        if not isinstance(idx, tuple):
+            idx = (idx,)
+
+        return type(expr)(expr.aggregate, (self.batch_iname_expr,) + idx)
+
+    def map_variable(self, expr, expn_state):
+        if not self.needs_batch_subscript(expr.name):
+            return super(_BatchVariableChanger, self).map_variable(expr, expn_state)
+
+        # *expr* is a bare variable (no ".aggregate"): subscript it directly.
+        return expr[self.batch_iname_expr]
+
+
+def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch"):
+    """Takes in a kernel that carries out an operation and returns a kernel
+    that carries out a batch of these operations.
+
+    :arg nbatches: the number of batches. May be a constant non-negative
+        integer or a string, which will be added as an integer argument.
+    :arg batch_varying_args: a list of argument names that vary per-batch,
+        or a string of comma-separated argument names.
+        Each such variable will have a batch index added.
+    :arg batch_iname_prefix: the base name handed to the kernel's variable
+        name generator to create the batch iname.
+    """
+
+    from pymbolic import var
+
+    # A bare string would otherwise be matched by substring below.
+    if isinstance(batch_varying_args, six.string_types):
+        batch_varying_args = [
+                name.strip() for name in batch_varying_args.split(",")
+                if name.strip()]
+
+    vng = knl.get_var_name_generator()
+    batch_iname = vng(batch_iname_prefix)
+    batch_iname_expr = var(batch_iname)
+
+    new_args = []
+
+    batch_dom_str = "{[%(iname)s]: 0 <= %(iname)s < %(nbatches)s}" % {
+            "iname": batch_iname,
+            "nbatches": nbatches,
+            }
+
+    if not isinstance(nbatches, int):
+        batch_dom_str = "[%s] -> " % nbatches + batch_dom_str
+        new_args.append(ValueArg(nbatches, dtype=knl.index_dtype))
+
+        nbatches_expr = var(nbatches)
+    else:
+        nbatches_expr = nbatches
+
+    batch_domain = isl.BasicSet(batch_dom_str)
+    new_domains = [batch_domain] + knl.domains
+
+    for arg in knl.args:
+        if arg.name in batch_varying_args:
+            if isinstance(arg, ValueArg):
+                arg = GlobalArg(arg.name, arg.dtype, shape=(nbatches_expr,),
+                        dim_tags="c")
+            else:
+                arg = arg.copy(
+                        shape=(nbatches_expr,) + arg.shape,
+                        dim_tags=("c",) * (len(arg.shape) + 1))
+
+        new_args.append(arg)
+
+    new_temps = {}
+
+    for temp in six.itervalues(knl.temporary_variables):
+        new_temps[temp.name] = temp.copy(
+                shape=(nbatches_expr,) + temp.shape,
+                dim_tags=("c",) * (len(temp.shape) + 1))
+
+    knl = knl.copy(
+            domains=new_domains,
+            args=new_args,
+            temporary_variables=new_temps)
+
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            knl.substitutions, vng)
+    bvc = _BatchVariableChanger(rule_mapping_context,
+            knl, batch_varying_args, batch_iname_expr)
+    return rule_mapping_context.finish_kernel(
+            bvc.map_kernel(knl))
+
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 4b2dcc5f553d51fe8bad2a3a401f80135242c9e8..7cad3504859d199c0581c8d3248ebafe50a34c4a 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2158,6 +2158,22 @@ def test_sci_notation_literal(ctx_factory):
     assert (np.abs(out.get() - 1e-12) < 1e-20).all()
 
 
+def test_to_batched(ctx_factory):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+         ''' { [i,j]: 0<=i,j<n } ''',
+         ''' out[i] = sum(j, a[i,j]*x[j])''')
+
+    bknl = lp.to_batched(knl, "nbatches", "out,x")
+
+    a = np.random.randn(5, 5)
+    x = np.random.randn(7, 5)
+
+    bknl(queue, a=a, x=x)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])