diff --git a/loopy/match.py b/loopy/match.py index ab0038af8dc5e9189a382bb76115998f57aef74e..ab5f11bb2450b9f8b1c0e87ff0b1400a49a691c4 100644 --- a/loopy/match.py +++ b/loopy/match.py @@ -273,7 +273,7 @@ def parse_match(expr): """Syntax examples:: * ``id:yoink and writes:a_temp`` - * ``id:yoink and (not writes:a_temp or tagged:input)`` + * ``id:yoink and (not writes:a_temp or tag:input)`` """ if not expr: return All() diff --git a/loopy/transform/batch.py b/loopy/transform/batch.py index e7a86300f9d040cba1688e5bb0f3dcbbd926f783..7e6b03581e39d03bc06d2f6d37f65a1d4ac6a386 100644 --- a/loopy/transform/batch.py +++ b/loopy/transform/batch.py @@ -38,6 +38,20 @@ __doc__ = """ # {{{ to_batched +def temp_needs_batching_if_not_sequential(tv, batch_varying_args): + from loopy.kernel.data import temp_var_scope + if tv.name in batch_varying_args: + return True + if tv.initializer is not None and tv.read_only: + # do not batch read_only temps if not in + # `batch_varying_args` + return False + if tv.scope == temp_var_scope.PRIVATE: + # do not batch private temps if not in `batch_varying args` + return False + return True + + class _BatchVariableChanger(RuleAwareIdentityMapper): def __init__(self, rule_mapping_context, kernel, batch_varying_args, batch_iname_expr, sequential): @@ -50,14 +64,17 @@ class _BatchVariableChanger(RuleAwareIdentityMapper): def needs_batch_subscript(self, name): tv = self.kernel.temporary_variables.get(name) - return ( - (not self.sequential - and (tv is not None - and not ( - tv.initializer is not None - and tv.read_only))) - or - name in self.batch_varying_args) + + if name in self.batch_varying_args: + return True + if not self.sequential: + if tv is None: + return False + if not temp_needs_batching_if_not_sequential(tv, + self.batch_varying_args): + return False + + return True def map_subscript(self, expr, expn_state): if not self.needs_batch_subscript(expr.aggregate.name): @@ -89,6 +106,10 @@ def to_batched(knl, nbatches, batch_varying_args, 
batch_iname_prefix="ibatch", sequential=False): """Takes in a kernel that carries out an operation and returns a kernel that carries out a batch of these operations. + .. note:: + For temporaries in a kernel that are private or read only + globals and if `sequential=True`, loopy does not batch these + variables unless explicitly mentioned in `batch_varying_args`. :arg nbatches: the number of batches. May be a constant non-negative integer or a string, which will be added as an integer argument. @@ -144,13 +165,13 @@ def to_batched(knl, nbatches, batch_varying_args, batch_iname_prefix="ibatch", new_temps = {} for temp in six.itervalues(knl.temporary_variables): - if temp.initializer is not None and temp.read_only: - new_temps[temp.name] = temp - else: + if temp_needs_batching_if_not_sequential(temp, batch_varying_args): new_temps[temp.name] = temp.copy( shape=(nbatches_expr,) + temp.shape, dim_tags=("c",) * (len(temp.shape) + 1), dim_names=_add_unique_dim_name("ibatch", temp.dim_names)) + else: + new_temps[temp.name] = temp knl = knl.copy(temporary_variables=new_temps) else: diff --git a/test/test_transform.py b/test/test_transform.py index e50605b46672f8e9c1817431f1577742b1f6fb4c..0e10db362f36b7fc258059c2ec7ed1a344b97212 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -96,13 +96,65 @@ def test_to_batched(ctx_factory): knl = lp.make_kernel( ''' { [i,j]: 0<=i,j