From cc5a489a026ccb3fac03d098bfbe0c70f1b80714 Mon Sep 17 00:00:00 2001 From: Tim Warburton <timwar@caam.rice.edu> Date: Tue, 25 Oct 2011 19:34:55 -0500 Subject: [PATCH] IndexVariableFinder: Add flag include_reduction_inames. --- loopy/__init__.py | 3 ++- loopy/kernel.py | 5 ++--- loopy/symbolic.py | 12 +++++++++++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/loopy/__init__.py b/loopy/__init__.py index 3c43339a4..a90e99ffd 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -230,7 +230,8 @@ def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=Non forced_iname_deps = [] from loopy.symbolic import IndexVariableFinder - dependencies = IndexVariableFinder()(expr.child) + dependencies = IndexVariableFinder( + include_reduction_inames=False)(expr.child) assert dependencies <= parent_inames diff --git a/loopy/kernel.py b/loopy/kernel.py index ffbae5657..f047f90cc 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -236,9 +236,8 @@ class Instruction(Record): @memoize_method def all_inames(self): from loopy.symbolic import IndexVariableFinder - index_vars = ( - IndexVariableFinder()(self.expression) - | IndexVariableFinder()(self.assignee)) + ivarf = IndexVariableFinder(include_reduction_inames=False) + index_vars = (ivarf(self.expression) | ivarf(self.assignee)) return index_vars | set(self.forced_iname_deps) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 4504daa7d..25667eaa1 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -485,6 +485,9 @@ class ReductionCallbackMapper(IdentityMapper): # {{{ index dependency finding class IndexVariableFinder(CombineMapper): + def __init__(self, include_reduction_inames): + self.include_reduction_inames = include_reduction_inames + def combine(self, values): import operator return reduce(operator.or_, values, set()) @@ -508,7 +511,14 @@ class IndexVariableFinder(CombineMapper): return result def map_reduction(self, expr): - return set(expr.inames) | self.rec(expr.expr) + result = self.rec(expr.expr) + if not set(expr.inames) <= result: + raise RuntimeError("reduction '%s' does not depend on " + "reduction inames" % expr) + if self.include_reduction_inames: + return result + else: + return result - set(expr.inames) # }}} -- GitLab