diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 00a92bdd0b69d0706ca44ca2922842fa88489c1e..d267f83bcc8687d13a5f11be6acb6f80bd50d824 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -145,6 +145,8 @@ Helper functions .. autofunction:: is_zero .. autofunction:: is_constant +.. autofunction:: flattened_sum +.. autofunction:: flattened_product .. autofunction:: register_constant_class .. autofunction:: unregister_constant_class .. autofunction:: variables @@ -159,7 +161,6 @@ vectors and matrices of :mod:`pymbolic` objects. .. autofunction:: make_sym_vector .. autofunction:: make_sym_array - Constants --------- @@ -1589,9 +1590,14 @@ def subscript(expression, index): return Subscript(expression, index) -def flattened_sum(components): - # flatten any potential sub-sums - queue = list(components) +def flattened_sum(terms): + r"""Recursively flattens all the top level :class:`Sum`\ s in *terms*. + + :arg terms: an :class:`~collections.abc.Iterable` of expressions. + :returns: a :class:`Sum` expression or, if there is only one term in + the sum, the respective term. + """ + queue = list(terms) done = [] while queue: @@ -1619,9 +1625,17 @@ def linear_combination(coefficients, expressions): if coefficient and expression) -def flattened_product(components): - # flatten any potential sub-products - queue = list(components) +def flattened_product(terms): + r"""Recursively flattens all the top level :class:`Product`\ s in *terms*. + + This operation does not change the order of the terms in the products, so + it does not require the product to be commutative. + + :arg terms: an :class:`~collections.abc.Iterable` of expressions. + :returns: a :class:`Product` expression or, if there is only one term in + the product, the respective term. + """ + queue = list(terms) done = [] while queue: @@ -1629,7 +1643,7 @@ def flattened_product(components): if is_zero(item): return 0 - if is_zero(item-1): + if is_zero(item - 1): continue if isinstance(item, Product):