Skip to content
Snippets Groups Projects
Commit 3feea4e8 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Add Collector, make DependencyMapper use it

parent 04574f89
No related branches found
No related tags found
No related merge requests found
......@@ -65,6 +65,8 @@ Base classes for new mappers
.. autoclass:: CombineMapper
.. autoclass:: Collector
.. autoclass:: IdentityMapper
.. autoclass:: WalkMapper
......@@ -297,6 +299,29 @@ class CombineMapper(RecursiveMapper):
# }}}
# {{{ collector
class Collector(CombineMapper):
"""A subclass of :class:`CombineMapper` for the common purpose of
collecting data derived from an expression in a set that gets 'unioned'
across children at each non-leaf node in the expression tree.
By default, nothing is collected. All leaves return empty sets.
"""
def combine(self, values):
import operator
return reduce(operator.or_, values, set())
def map_constant(self, expr):
return set()
map_variable = map_constant
map_function_symbol = map_constant
# }}}
# {{{ identity mapper
class IdentityMapper(Mapper):
......
......@@ -22,10 +22,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from pymbolic.mapper import CombineMapper, CSECachingMapperMixin
from pymbolic.mapper import Collector, CSECachingMapperMixin
class DependencyMapper(CSECachingMapperMixin, CombineMapper):
class DependencyMapper(CSECachingMapperMixin, Collector):
"""Maps an expression to the :class:`set` of expressions it
is based on. The ``include_*`` arguments to the constructor
determine which types of objects occur in this output set.
......@@ -61,19 +61,9 @@ class DependencyMapper(CSECachingMapperMixin, CombineMapper):
self.include_cses = include_cses
def combine(self, values):
import operator
return reduce(operator.or_, values, set())
def map_constant(self, expr):
return set()
def map_variable(self, expr):
return set([expr])
def map_function_symbol(self, expr):
return set()
def map_call(self, expr):
if self.include_calls == "descend_args":
return self.combine(
......@@ -81,7 +71,7 @@ class DependencyMapper(CSECachingMapperMixin, CombineMapper):
elif self.include_calls:
return set([expr])
else:
return CombineMapper.map_call(self, expr)
return super(DependencyMapper, self).map_call(expr)
def map_call_with_kwargs(self, expr):
if self.include_calls == "descend_args":
......@@ -92,25 +82,25 @@ class DependencyMapper(CSECachingMapperMixin, CombineMapper):
elif self.include_calls:
return set([expr])
else:
return CombineMapper.map_call_with_kwargs(self, expr)
return super(DependencyMapper, self).map_call_with_kwargs(expr)
def map_lookup(self, expr):
if self.include_lookups:
return set([expr])
else:
return CombineMapper.map_lookup(self, expr)
return super(DependencyMapper, self).map_lookup(expr)
def map_subscript(self, expr):
if self.include_subscripts:
return set([expr])
else:
return CombineMapper.map_subscript(self, expr)
return super(DependencyMapper, self).map_subscript(expr)
def map_common_subexpression_uncached(self, expr):
if self.include_cses:
return set([expr])
else:
return CombineMapper.map_common_subexpression(self, expr)
return super(DependencyMapper, self).map_common_subexpression(expr)
def map_slice(self, expr):
return self.combine(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment