diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index e0c0d23028d09c5bcd485f4c5181a3d2e4265a9d..6d0a1591fb093e2c907fdc63fc59ed4c364fbd68 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -65,6 +65,8 @@ Base classes for new mappers .. autoclass:: CombineMapper +.. autoclass:: Collector + .. autoclass:: IdentityMapper .. autoclass:: WalkMapper @@ -297,6 +299,29 @@ class CombineMapper(RecursiveMapper): # }}} +# {{{ collector + +class Collector(CombineMapper): + """A subclass of :class:`CombineMapper` for the common purpose of + collecting data derived from an expression in a set that gets 'unioned' + across children at each non-leaf node in the expression tree. + + By default, nothing is collected. All leaves return empty sets. + """ + + def combine(self, values): + import operator + return reduce(operator.or_, values, set()) + + def map_constant(self, expr): + return set() + + map_variable = map_constant + map_function_symbol = map_constant + +# }}} + + # {{{ identity mapper class IdentityMapper(Mapper): diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index 36d061bde227bdcbf88997267761538975bd96c7..242b156393bec72d5f88298627bb28651eee6e03 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -22,10 +22,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from pymbolic.mapper import CombineMapper, CSECachingMapperMixin +from pymbolic.mapper import Collector, CSECachingMapperMixin -class DependencyMapper(CSECachingMapperMixin, CombineMapper): +class DependencyMapper(CSECachingMapperMixin, Collector): """Maps an expression to the :class:`set` of expressions it is based on. The ``include_*`` arguments to the constructor determine which types of objects occur in this output set. @@ -61,19 +61,9 @@ class DependencyMapper(CSECachingMapperMixin, CombineMapper): self.include_cses = include_cses - def combine(self, values): - import operator - return reduce(operator.or_, values, set()) - - def map_constant(self, expr): - return set() - def map_variable(self, expr): return set([expr]) - def map_function_symbol(self, expr): - return set() - def map_call(self, expr): if self.include_calls == "descend_args": return self.combine( @@ -81,7 +71,7 @@ class DependencyMapper(CSECachingMapperMixin, CombineMapper): elif self.include_calls: return set([expr]) else: - return CombineMapper.map_call(self, expr) + return super(DependencyMapper, self).map_call(expr) def map_call_with_kwargs(self, expr): if self.include_calls == "descend_args": @@ -92,25 +82,25 @@ class DependencyMapper(CSECachingMapperMixin, CombineMapper): elif self.include_calls: return set([expr]) else: - return CombineMapper.map_call_with_kwargs(self, expr) + return super(DependencyMapper, self).map_call_with_kwargs(expr) def map_lookup(self, expr): if self.include_lookups: return set([expr]) else: - return CombineMapper.map_lookup(self, expr) + return super(DependencyMapper, self).map_lookup(expr) def map_subscript(self, expr): if self.include_subscripts: return set([expr]) else: - return CombineMapper.map_subscript(self, expr) + return super(DependencyMapper, self).map_subscript(expr) def map_common_subexpression_uncached(self, expr): if self.include_cses: return set([expr]) else: - return CombineMapper.map_common_subexpression(self, expr) + return super(DependencyMapper, self).map_common_subexpression(expr) def map_slice(self, expr): return self.combine(