diff --git a/doc/mappers.rst b/doc/mappers.rst index 93d39b42183a6e505db18a4aa667f9fd8a1cd7d9..4a92531f903fe5732e96b0a01d3cfc83e93f9b1b 100644 --- a/doc/mappers.rst +++ b/doc/mappers.rst @@ -39,4 +39,10 @@ Finding expression properties .. autoclass:: FlopCounter .. autoclass:: CSEAwareFlopCounter +Analysis tools +^^^^^^^^^^^^^^ + +.. automodule:: pymbolic.mapper.analysis + + .. vim: sw=4 diff --git a/pymbolic/mapper/analysis.py b/pymbolic/mapper/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..633c0340ca54294d3c4dd8fb78307f520b53dd98 --- /dev/null +++ b/pymbolic/mapper/analysis.py @@ -0,0 +1,64 @@ +__copyright__ = """Copyright (C) 2022 University of Illinois Board of Trustees""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +from pymbolic.mapper import CachedWalkMapper + + +__doc__ = """ +.. autoclass:: NodeCountMapper +.. autofunction:: get_num_nodes +""" + + +# {{{ NodeCountMapper + +class NodeCountMapper(CachedWalkMapper): + """ + Counts the number of nodes in an expression tree. Nodes that occur + repeatedly as well as :class:`~pymbolic.primitives.CommonSubexpression` + nodes are only counted once. + + .. attribute:: count + + The number of nodes. + """ + + def __init__(self) -> None: + super().__init__() + self.count = 0 + + def post_visit(self, expr) -> None: + self.count += 1 + + +def get_num_nodes(expr) -> int: + """Returns the number of nodes in *expr*. Nodes that occur + repeatedly as well as :class:`~pymbolic.primitives.CommonSubexpression` + nodes are only counted once.""" + + ncm = NodeCountMapper() + ncm(expr) + + return ncm.count + +# }}} diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 08a9d81cd24fe52dbc147ec561d2ca4e952e7510..d30dbdbb748a478e0ba6faee6586667fd2f510f1 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -938,6 +938,23 @@ def test_cached_mapper_differentiates_float_int(): # }}} +def test_nodecount(): + from pymbolic.mapper.analysis import get_num_nodes + expr = prim.Sum((4, 4.0)) + + assert get_num_nodes(expr) == 3 + + x = prim.Variable("x") + y = prim.Variable("y") + z = prim.Variable("z") + + subexpr = prim.CommonSubexpression(4 * (x**2 + y + z)) + expr = 3*subexpr + subexpr + subexpr + subexpr + expr = expr + expr + expr + + assert get_num_nodes(expr) == 12 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: