diff --git a/loopy/match.py b/loopy/match.py index a417c2799813fa13df013cefbb3b6c2227687dff..3c047e463939cd67a4878d202a754c0cab48058d 100644 --- a/loopy/match.py +++ b/loopy/match.py @@ -276,7 +276,7 @@ def parse_match(expr): """Syntax examples:: * ``id:yoink and writes:a_temp`` - * ``id:yoink and (not writes:a_temp or tagged:input)`` + * ``id:yoink and (not writes:a_temp or tag:input)`` """ if not expr: return All() diff --git a/loopy/transform/batch.py b/loopy/transform/batch.py index e7a86300f9d040cba1688e5bb0f3dcbbd926f783..7e6b03581e39d03bc06d2f6d37f65a1d4ac6a386 100644 --- a/loopy/transform/batch.py +++ b/loopy/transform/batch.py @@ -38,6 +38,20 @@ __doc__ = """ # {{{ to_batched +def temp_needs_batching_if_not_sequential(tv, batch_varying_args): + from loopy.kernel.data import temp_var_scope + if tv.name in batch_varying_args: + return True + if tv.initializer is not None and tv.read_only: + # do not batch read_only temps if not in + # `batch_varying_args` + return False + if tv.scope == temp_var_scope.PRIVATE: + # do not batch private temps if not in `batch_varying_args` + return False + return True + + class _BatchVariableChanger(RuleAwareIdentityMapper): def __init__(self, rule_mapping_context, kernel, batch_varying_args, batch_iname_expr, sequential): @@ -50,14 +64,17 @@ class _BatchVariableChanger(RuleAwareIdentityMapper): def needs_batch_subscript(self, name): tv = self.kernel.temporary_variables.get(name) - return ( - (not self.sequential - and (tv is not None - and not ( - tv.initializer is not None - and tv.read_only))) - or - name in self.batch_varying_args) + + if name in self.batch_varying_args: + return True + if not self.sequential: + if tv is None: + return False + if not temp_needs_batching_if_not_sequential(tv, + self.batch_varying_args): + return False + + return True def map_subscript(self, expr, expn_state): if not self.needs_batch_subscript(expr.aggregate.name): @@ -89,6 +106,10 @@ def to_batched(knl, nbatches, batch_varying_args, 
batch_iname_prefix="ibatch", sequential=False): """Takes in a kernel that carries out an operation and returns a kernel that carries out a batch of these operations. + .. note:: + For temporaries in a kernel that are private or read only + globals and if `sequential=True`, loopy does not batch these + variables unless explicitly mentioned in `batch_varying_args`. :arg nbatches: the number of batches. May be a constant non-negative integer or a string, which will be added as an integer argument. @@ -144,13 +165,13 @@ def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch", new_temps = {} for temp in six.itervalues(knl.temporary_variables): - if temp.initializer is not None and temp.read_only: - new_temps[temp.name] = temp - else: + if temp_needs_batching_if_not_sequential(temp, batch_varying_args): new_temps[temp.name] = temp.copy( shape=(nbatches_expr,) + temp.shape, dim_tags=("c",) * (len(temp.shape) + 1), dim_names=_add_unique_dim_name("ibatch", temp.dim_names)) + else: + new_temps[temp.name] = temp knl = knl.copy(temporary_variables=new_temps) else: diff --git a/test/test_transform.py b/test/test_transform.py index e50605b46672f8e9c1817431f1577742b1f6fb4c..0e10db362f36b7fc258059c2ec7ed1a344b97212 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -96,13 +96,65 @@ def test_to_batched(ctx_factory): knl = lp.make_kernel( ''' { [i,j]: 0<=i,j<n } ''', ''' out[i] = sum(j, a[i,j]*x[j])''') + knl = lp.add_and_infer_dtypes(knl, dict(out=np.float32, + x=np.float32, + a=np.float32)) bknl = lp.to_batched(knl, "nbatches", "out,x") + ref_knl = lp.make_kernel( + ''' { [i,j,k]: 0<=i,j<n and 0<=k<nbatches} ''', + '''out[k, i] = sum(j, a[i,j]*x[k, j])''') + ref_knl = lp.add_and_infer_dtypes(ref_knl, dict(out=np.float32, + x=np.float32, + a=np.float32)) + + a = np.random.randn(5, 5).astype(np.float32) + x = np.random.randn(7, 5).astype(np.float32) + + # Running both the kernels + evt, (out1, ) = bknl(queue, a=a, x=x, n=5, nbatches=7) 
+ evt, (out2, ) = ref_knl(queue, a=a, x=x, n=5, nbatches=7) + + # checking that the outputs are the same + assert np.linalg.norm(out1-out2) < 1e-15 + + +def test_to_batched_temp(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel( + ''' { [i,j]: 0<=i,j<n } ''', + ''' cnst = 2.0 + out[i] = sum(j, cnst*a[i,j]*x[j])''', + [lp.TemporaryVariable( + "cnst", + dtype=np.float32, + shape=(), + scope=lp.temp_var_scope.PRIVATE), '...']) + knl = lp.add_and_infer_dtypes(knl, dict(out=np.float32, + x=np.float32, + a=np.float32)) + ref_knl = lp.make_kernel( + ''' { [i,j]: 0<=i,j<n } ''', + '''out[i] = sum(j, 2.0*a[i,j]*x[j])''') + ref_knl = lp.add_and_infer_dtypes(ref_knl, dict(out=np.float32, + x=np.float32, + a=np.float32)) + + bknl = lp.to_batched(knl, "nbatches", "out,x") + bref_knl = lp.to_batched(ref_knl, "nbatches", "out,x") + + # checking that cnst is not being batched + assert bknl.temporary_variables['cnst'].shape == () + a = np.random.randn(5, 5) x = np.random.randn(7, 5) - bknl(queue, a=a, x=x) + # Checking that the program compiles and the logic is correct + lp.auto_test_vs_ref( + bref_knl, ctx, bknl, + parameters=dict(a=a, x=x, n=5, nbatches=7)) def test_add_barrier(ctx_factory):