Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Showing 2307 additions and 986 deletions
#! /bin/sh
rsync --progress --verbose --archive --delete _build/html/* doc-upload:doc/pymbolic
rsync --verbose --archive --delete _build/html/ doc-upload:doc/pymbolic
......@@ -8,7 +8,8 @@ Parser
.. function:: parse(expr_str)
Return a :class:`pymbolic.primitives.Expression` tree corresponding to *expr_str*.
Return a :class:`pymbolic.primitives.ExpressionNode` tree corresponding
to *expr_str*.
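For example, a short round trip through the parser (the repr shown here is
illustrative and may differ slightly between versions):
.. doctest::
>>> from pymbolic import parse
>>> parse("x + 2*y")
Sum((Variable('x'), Product((2, Variable('y')))))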
The parser is also relatively easy to extend. See the source code of the following
class.
......@@ -43,3 +44,14 @@ Interoperability with Python's :mod:`ast` module
------------------------------------------------
.. automodule:: pymbolic.interop.ast
Interoperability with :mod:`matchpy.functions` module
-----------------------------------------------------
.. automodule:: pymbolic.interop.matchpy
Visualizing Expressions
=======================
.. autofunction:: pymbolic.imperative.utils.get_dot_dependency_graph
# See https://github.com/inducer/pymbolic/pull/110 for context
import sys
from pymbolic import parse
from pymbolic.mapper import CachedIdentityMapper
from pymbolic.mapper.optimize import optimize_mapper
from pymbolic.primitives import Variable
code = ("(-1)*((cse_577[_pt_data_48[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
"_pt_data_49[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_48[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_577[_pt_data_46[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_47[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_46[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_577[_pt_data_7[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_43[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_7[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_577[_pt_data_44[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_45[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_44[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_579[_pt_data_68[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_69[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_68[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_579[_pt_data_66[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_67[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_66[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_579[_pt_data_50[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_63[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_50[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_579[_pt_data_64[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_65[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_64[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_581[_pt_data_88[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_89[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_88[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_581[_pt_data_86[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_87[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_86[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_581[_pt_data_70[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_83[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_70[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_581[_pt_data_84[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_85[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_84[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_582[_pt_data_107[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_108[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_107[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_582[_pt_data_105[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_106[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_105[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_582[_pt_data_90[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_102[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_90[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_582[_pt_data_103[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_104[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_103[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0))"
" + (cse_572[_pt_data_48[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_49[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_48[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0) "
"+ (cse_572[_pt_data_46[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_47[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_46[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0) "
"+ (cse_572[_pt_data_7[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_43[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_7[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_572[_pt_data_44[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_45[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_44[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_573[_pt_data_68[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_69[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_68[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_573[_pt_data_66[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_67[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_66[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_573[_pt_data_50[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_63[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_50[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_573[_pt_data_64[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_65[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_64[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_574[_pt_data_88[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_89[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_88[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_574[_pt_data_86[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_87[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_86[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0) "
"+ (cse_574[_pt_data_70[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_83[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_70[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_574[_pt_data_84[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_85[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_84[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_575[_pt_data_107[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_108[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_107[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_575[_pt_data_105[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_106[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_105[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_575[_pt_data_90[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_102[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_90[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_575[_pt_data_103[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_104[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_103[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
)
expr = parse(code)
expr = CachedIdentityMapper()(expr) # remove duplicate nodes
replacements = {
"iface_ensm15": Variable("_0"),
"iel_ensm15": Variable("_1"),
"idof_ensm15": Variable("_2"),
}
@optimize_mapper(drop_args=True, drop_kwargs=True,
# inline_cache=True, inline_rec=True,
inline_get_cache_key=True,
print_modified_code_file=sys.stdout)
class Renamer(CachedIdentityMapper):
def map_variable(self, expr):
return replacements.get(expr.name, expr)
def get_cache_key(self, expr):
# Must add 'type(expr)', to differentiate between python scalar types.
# In Python, the following conditions are true: "hash(4) == hash(4.0)"
# and "4 == 4.0", but their traversal results cannot be re-used.
return (type(expr), expr)
def main():
mapper = Renamer()
mapper(expr)
# print(type(new_expr))
if __name__ == "__main__":
from time import time
if 1:
t_start = time()
for _ in range(10_000):
main()
t_end = time()
print(f"Took: {t_end-t_start} secs.")
else:
import pyinstrument
from pyinstrument.renderers import SpeedscopeRenderer
prof = pyinstrument.Profiler()
with prof:
for _ in range(10_000):
main()
with open("ss.json", "w") as outf:
outf.write(prof.output(SpeedscopeRenderer(show_all=True)))
from __future__ import division
from __future__ import absolute_import
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
......@@ -24,137 +24,99 @@ THE SOFTWARE.
"""
__doc__ = """
Pymbolic is a simple and extensible package for precise manipulation of
symbolic expressions in Python. It doesn't try to compete with :mod:`sympy` as
a computer algebra system. Pymbolic emphasizes providing an extensible
expression tree and a flexible, extensible way to manipulate it.
A taste of :mod:`pymbolic`
--------------------------
Follow along on a simple example. Let's import :mod:`pymbolic` and create a
symbol, *x* in this case.
.. doctest::
>>> import pymbolic as pmbl
>>> x = pmbl.var("x")
>>> x
Variable('x')
Next, let's create an expression using *x*:
.. doctest::
>>> u = (x+1)**5
>>> u
Power(Sum((Variable('x'), 1)), 5)
>>> print(u)
(x + 1)**5
Note the two ways an expression can be printed, namely :func:`repr` and
:func:`str`. :mod:`pymbolic` purposefully distinguishes the two.
:mod:`pymbolic` does not perform any manipulations on expressions
you put in. It has a few of those built in, but that's not really the point:
.. doctest::
>>> print(pmbl.differentiate(u, 'x'))
5*(x + 1)**4
.. _custom-manipulation:
Manipulating expressions
^^^^^^^^^^^^^^^^^^^^^^^^
The point is for you to be able to easily write so-called *mappers* to
manipulate expressions. Suppose we would like all sums replaced by
products:
.. doctest::
>>> from pymbolic.mapper import IdentityMapper
>>> class MyMapper(IdentityMapper):
... def map_sum(self, expr):
... return pmbl.primitives.Product(expr.children)
...
>>> print(u)
(x + 1)**5
>>> print(MyMapper()(u))
(x*1)**5
Custom Objects
^^^^^^^^^^^^^^
You can also easily define your own objects to use inside an expression:
.. doctest::
>>> from pymbolic.primitives import Expression
>>> class FancyOperator(Expression):
... def __init__(self, operand):
... self.operand = operand
...
... def __getinitargs__(self):
... return (self.operand,)
...
... mapper_method = "map_fancy_operator"
...
>>> u
Power(Sum((Variable('x'), 1)), 5)
>>> 17*FancyOperator(u)
Product((17, FancyOperator(Power(Sum((Variable('x'), 1)), 5))))
As a final example, we can now derive from *MyMapper* to multiply all
*FancyOperator* instances by 2.
.. doctest::
>>> class MyMapper2(MyMapper):
... def map_fancy_operator(self, expr):
... return 2*FancyOperator(self.rec(expr.operand))
...
>>> MyMapper2()(FancyOperator(u))
Product((2, FancyOperator(Power(Product((Variable('x'), 1)), 5))))
"""
from functools import partial
from pytools import module_getattr_for_deprecations
from . import compiler, parser, primitives
from .compiler import compile
from .mapper import (
dependency,
differentiator,
distributor,
evaluator,
flattener,
stringifier,
substitutor,
)
from .mapper.differentiator import differentiate, differentiate as diff
from .mapper.distributor import distribute, distribute as expand
from .mapper.evaluator import evaluate, evaluate_kw
from .mapper.flattener import flatten
from .mapper.substitutor import substitute
from .parser import parse
from .primitives import ( # noqa: N813
ExpressionNode,
Variable,
Variable as var,
disable_subscript_by_getitem,
expr_dataclass,
flattened_product,
flattened_sum,
linear_combination,
make_common_subexpression as cse,
make_sym_vector,
quotient,
subscript,
variables,
)
from .typing import (
ArithmeticExpression,
Bool,
Expression,
Expression as _TypingExpression,
Number,
RealNumber,
Scalar,
)
from pymbolic.version import VERSION_TEXT as __version__ # noqa
import pymbolic.parser
import pymbolic.compiler
import pymbolic.mapper.evaluator
import pymbolic.mapper.stringifier
import pymbolic.mapper.dependency
import pymbolic.mapper.substitutor
import pymbolic.mapper.differentiator
import pymbolic.mapper.distributor
import pymbolic.mapper.flattener
import pymbolic.primitives
from pymbolic.polynomial import Polynomial # noqa
var = pymbolic.primitives.Variable
variables = pymbolic.primitives.variables
flattened_sum = pymbolic.primitives.flattened_sum
subscript = pymbolic.primitives.subscript
flattened_product = pymbolic.primitives.flattened_product
quotient = pymbolic.primitives.quotient
linear_combination = pymbolic.primitives.linear_combination
cse = pymbolic.primitives.make_common_subexpression
make_sym_vector = pymbolic.primitives.make_sym_vector
disable_subscript_by_getitem = pymbolic.primitives.disable_subscript_by_getitem
parse = pymbolic.parser.parse
evaluate = pymbolic.mapper.evaluator.evaluate
evaluate_kw = pymbolic.mapper.evaluator.evaluate_kw
compile = pymbolic.compiler.compile
substitute = pymbolic.mapper.substitutor.substitute
diff = differentiate = pymbolic.mapper.differentiator.differentiate
expand = pymbolic.mapper.distributor.distribute
distribute = pymbolic.mapper.distributor.distribute
flatten = pymbolic.mapper.flattener.flatten
__all__ = (
"ArithmeticExpression",
"Bool",
"Expression",
"ExpressionNode",
"Number",
"RealNumber",
"Scalar",
"Variable",
"compile",
"compiler",
"cse",
"dependency",
"diff",
"differentiate",
"differentiator",
"disable_subscript_by_getitem",
"distribute",
"distributor",
"evaluate",
"evaluate_kw",
"evaluator",
"expand",
"expr_dataclass",
"flatten",
"flattened_product",
"flattened_sum",
"flattener",
"linear_combination",
"make_sym_vector",
"parse",
"parser",
"primitives",
"quotient",
"stringifier",
"subscript",
"substitute",
"substitutor",
"var",
"variables",
)
__getattr__ = partial(module_getattr_for_deprecations, __name__, {
"ExpressionT": ("pymbolic.typing.Expression", _TypingExpression, 2026),
"ArithmeticExpressionT": ("ArithmeticExpression", ArithmeticExpression, 2026),
"BoolT": ("Bool", Bool, 2026),
"ScalarT": ("Scalar", Scalar, 2026),
})
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
import six
from six.moves import range
from six.moves import zip
from functools import reduce
"""
.. autofunction:: integer_power
.. autofunction:: extended_euclidean
.. autofunction:: gcd
.. autofunction:: lcm
.. autofunction:: fft
.. autofunction:: ifft
.. autofunction:: sym_fft
.. autofunction:: reduced_row_echelon_form
.. autofunction:: solve_affine_equations_for
"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
......@@ -28,8 +35,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import cmath
from pytools import memoize
import operator
import sys
from typing import TYPE_CHECKING, overload
from warnings import warn
from pytools import MovedFunctionDeprecationWrapper, memoize
if TYPE_CHECKING:
import numpy as np
if getattr(sys, "_BUILDING_SPHINX_DOCS", None):
import numpy as np # noqa: TC002
# {{{ integer powers
......@@ -37,7 +56,7 @@ from pytools import memoize
def integer_power(x, n, one=1):
"""Compute :math:`x^n` using only multiplications.
See also the `C2 wiki <http://c2.com/cgi/wiki?IntegerPowerAlgorithm>`_.
See also the `C2 wiki <https://wiki.c2.com/?IntegerPowerAlgorithm>`__.
"""
assert isinstance(n, int)
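# A minimal sanity check of the repeated-multiplication evaluation (the import
# path pymbolic.algorithm is assumed here; adjust if the module lives elsewhere):
#
#   from pymbolic.algorithm import integer_power
#   assert integer_power(3, 5) == 243
#   assert integer_power(2, 10) == 1024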
......@@ -69,7 +88,7 @@ def extended_euclidean(q, r):
See also the
`Wikipedia article on the Euclidean algorithm
<https://en.wikipedia.org/wiki/Euclidean_algorithm>`_.
<https://en.wikipedia.org/wiki/Euclidean_algorithm>`__.
"""
import pymbolic.traits as traits
......@@ -81,14 +100,14 @@ def extended_euclidean(q, r):
p, a, b = extended_euclidean(r, q)
return p, b, a
Q = 1, 0
R = 0, 1
Q = 1, 0 # noqa
R = 0, 1 # noqa
while r:
quot, t = divmod(q, r)
T = Q[0] - quot*R[0], Q[1] - quot*R[1]
T = Q[0] - quot*R[0], Q[1] - quot*R[1] # noqa
q, r = r, t
Q, R = R, T
Q, R = R, T # noqa: N806
return q, Q[0], Q[1]
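# Hedged usage sketch: whichever path the algorithm takes, the returned triple
# (p, a, b) should satisfy a*q + b*r == p == gcd(q, r):
#
#   from pymbolic.algorithm import extended_euclidean
#   p, a, b = extended_euclidean(240, 46)
#   assert p == 2 and a*240 + b*46 == p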
......@@ -103,6 +122,7 @@ def gcd_many(*args):
elif len(args) == 1:
return args[0]
else:
from functools import reduce
return reduce(gcd, args)
......@@ -115,23 +135,28 @@ def lcm(q, r):
# {{{ fft
@memoize
def find_factors(N):
def find_factors(n):
from math import sqrt
N1 = 2
max_N1 = int(sqrt(N))+1
while N % N1 != 0 and N1 <= max_N1:
N1 += 1
n1 = 2
max_n1 = int(sqrt(n))+1
while n % n1 != 0 and n1 <= max_n1:
n1 += 1
if N1 > max_N1:
N1 = N
if n1 > max_n1:
n1 = n
N2 = N // N1
n2 = n // n1
return N1, N2
return n1, n2
def fft(x, sign=1, wrap_intermediate=lambda x: x):
def fft(x, sign=1,
wrap_intermediate=None,
*,
wrap_intermediate_with_level=None,
complex_dtype=None,
custom_np=None, level=0):
r"""Computes the Fourier transform of x:
.. math::
......@@ -141,37 +166,81 @@ def fft(x, sign=1, wrap_intermediate=lambda x: x):
where :math:`z = \exp(-2i\pi\operatorname{sign}/n)` and ``n == len(x)``.
Works for all positive *n*.
See also `Wikipedia <http://en.wikipedia.org/wiki/Cooley-Tukey_FFT_algorithm>`_.
See also `Wikipedia <https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm>`__.
"""
# revision 293076305, http://is.gd/1c7PI
# revision 293076305
# https://en.wikipedia.org/w/index.php?title=Cooley-Tukey_FFT_algorithm&oldid=293076305
# {{{ parameter processing
if wrap_intermediate is not None and wrap_intermediate_with_level is not None:
raise TypeError("may specify at most one of wrap_intermediate and "
"wrap_intermediate_with_level")
if wrap_intermediate is not None:
from warnings import warn
warn("wrap_intermediate is deprecated. Use wrap_intermediate_with_level "
"instead. wrap_intermediate will stop working in 2023.",
DeprecationWarning, stacklevel=2)
def wrap_intermediate_with_level(level, x): # pylint: disable=function-redefined
return wrap_intermediate(x)
if wrap_intermediate_with_level is None:
def wrap_intermediate_with_level(level, x):
return x
from math import pi
import numpy
if custom_np is None:
import numpy as custom_np
N = len(x)
if complex_dtype is None:
if x.dtype.kind == "c":
complex_dtype = x.dtype
else:
from warnings import warn
warn("Not supplying complex_dtype is deprecated, falling back "
"to complex128 for now. This will stop working in 2023.",
DeprecationWarning, stacklevel=2)
complex_dtype = custom_np.complex128
complex_dtype = custom_np.dtype(complex_dtype)
# }}}
if N == 1:
n = len(x)
if n == 1:
return x
N1, N2 = find_factors(N)
N1, N2 = find_factors(n) # noqa: N806
scalar_tp = complex_dtype.type
sub_ffts = [
wrap_intermediate(
fft(x[n1::N1], sign, wrap_intermediate)
* numpy.exp(numpy.linspace(0, sign*(-2j)*pi*n1/N1, N2,
endpoint=False)))
wrap_intermediate_with_level(level,
fft(x[n1::N1], sign, wrap_intermediate, custom_np=custom_np,
level=level+1, complex_dtype=complex_dtype)
* custom_np.exp(
sign*-2j*pi*n1/(N1*N2)
* custom_np.arange(0, N2, dtype=complex_dtype)))
for n1 in range(N1)]
return numpy.hstack([
sum(subvec * cmath.exp(sign*(-2j)*pi*n1*k1/N1)
return custom_np.concatenate([
sum(subvec * scalar_tp(custom_np.exp(sign*(-2j)*pi*n1*k1/N1))
for n1, subvec in enumerate(sub_ffts))
for k1 in range(N1)
])
], axis=0)
def ifft(x, wrap_intermediate=lambda x: x):
return (1/len(x))*fft(x, -1, wrap_intermediate)
def ifft(x, wrap_intermediate=None,
*,
wrap_intermediate_with_level=None,
complex_dtype=None,
custom_np=None):
return (1/len(x))*fft(x, sign=-1, wrap_intermediate=wrap_intermediate,
wrap_intermediate_with_level=wrap_intermediate_with_level,
complex_dtype=complex_dtype, custom_np=custom_np)
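# Hedged numerical check: the formula above suggests the same sign and
# normalization conventions as numpy.fft, but treat that as an assumption to
# verify rather than a guarantee:
#
#   import numpy as np
#   from pymbolic.algorithm import fft, ifft
#   x = np.random.rand(12).astype(np.complex128)
#   X = fft(x, complex_dtype=np.complex128)
#   assert np.allclose(X, np.fft.fft(x))
#   assert np.allclose(ifft(X, complex_dtype=np.complex128), x)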
def sym_fft(x, sign=1):
......@@ -182,7 +251,7 @@ def sym_fft(x, sign=1):
wrappers at opportune points.
"""
from pymbolic.mapper import IdentityMapper, CSECachingMapperMixin
from pymbolic.mapper import CSECachingMapperMixin, IdentityMapper
class NearZeroKiller(CSECachingMapperMixin, IdentityMapper):
map_common_subexpression_uncached = \
......@@ -207,10 +276,12 @@ def sym_fft(x, sign=1):
def wrap_intermediate(x):
if len(x) > 1:
from pymbolic.primitives import CommonSubexpression
from pymbolic.primitives import CommonSubexpression, cse_scope
result = numpy.empty(len(x), dtype=object)
for i, x_i in enumerate(x):
result[i] = CommonSubexpression(x_i)
result[i] = CommonSubexpression(x_i, scope=cse_scope.EVALUATION)
return result
else:
return x
......@@ -222,25 +293,59 @@ def sym_fft(x, sign=1):
# }}}
def csr_matrix_multiply(S, x):
def csr_matrix_multiply(S, x): # noqa
"""Multiplies a :class:`scipy.sparse.csr_matrix` S by an object-array vector x.
"""
h, w = S.shape
h, _w = S.shape
import numpy
result = numpy.empty_like(x)
for i in range(h):
result[i] = sum(S.data[idx]*x[S.indices[idx]]
result[i] = sum(S.data[idx]*x[S.indices[idx]] # pylint:disable=unsupported-assignment-operation
for idx in range(S.indptr[i], S.indptr[i+1]))
return result
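# Illustrative sketch (requires scipy; the repr of the symbolic entries is
# approximate):
#
#   import numpy as np
#   from scipy.sparse import csr_matrix
#   from pymbolic import var
#   from pymbolic.algorithm import csr_matrix_multiply
#   S = csr_matrix(np.array([[1, 0], [2, 3]]))
#   x = np.array([var("a"), var("b")], dtype=object)
#   csr_matrix_multiply(S, x)  # roughly: array([a, 2*a + 3*b], dtype=object)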
# {{{ gaussian elimination
# {{{ reduced_row_echelon_form
@overload
def reduced_row_echelon_form(
mat: np.ndarray,
*, integral: bool | None = None,
) -> np.ndarray:
...
def gaussian_elimination(mat, rhs):
@overload
def reduced_row_echelon_form(
mat: np.ndarray,
rhs: np.ndarray,
*, integral: bool | None = None,
) -> tuple[np.ndarray, np.ndarray]:
...
def reduced_row_echelon_form(
mat: np.ndarray,
rhs: np.ndarray | None = None,
integral: bool | None = None,
) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
m, n = mat.shape
mat = mat.copy()
if rhs is not None:
rhs = rhs.copy()
if integral is None:
warn(
"Not specifying 'integral' is deprecated, please add it as an argument. "
"This will stop being supported in 2025.",
DeprecationWarning, stacklevel=2)
div_func = operator.floordiv if integral else operator.truediv
i = 0
j = 0
......@@ -259,8 +364,9 @@ def gaussian_elimination(mat, rhs):
# swap rows i and nonz
mat[i], mat[nonz_row] = \
(mat[nonz_row].copy(), mat[i].copy())
rhs[i], rhs[nonz_row] = \
(rhs[nonz_row].copy(), rhs[i].copy())
if rhs is not None:
rhs[i], rhs[nonz_row] = \
(rhs[nonz_row].copy(), rhs[i].copy())
for u in range(0, m):
if u == i:
......@@ -269,12 +375,13 @@ def gaussian_elimination(mat, rhs):
# already 0
continue
l = lcm(mat[u, j], mat[i, j])
u_fac = l//mat[u, j]
i_fac = l//mat[i, j]
ell = lcm(mat[u, j], mat[i, j])
u_fac = div_func(ell, mat[u, j])
i_fac = div_func(ell, mat[i, j])
mat[u] = u_fac*mat[u] - i_fac*mat[i]
rhs[u] = u_fac*rhs[u] - i_fac*rhs[i]
if rhs is not None:
rhs[u] = u_fac*rhs[u] - i_fac*rhs[i]
assert mat[u, j] == 0
......@@ -282,20 +389,38 @@ def gaussian_elimination(mat, rhs):
j += 1
for i in range(m):
g = gcd_many(*(
[a for a in mat[i] if a]
+
[a for a in rhs[i] if a]))
if integral:
for i in range(m):
g = gcd_many(*(
[a for a in mat[i] if a]
+
[a for a in rhs[i] if a] if rhs is not None else []))
mat[i] //= g
if rhs is not None:
rhs[i] //= g
import numpy as np
mat[i] //= g
rhs[i] //= g
from pymbolic.mapper.flattener import flatten
vec_flatten = np.vectorize(flatten, otypes=[object])
return mat, rhs
for i in range(m):
mat[i] = vec_flatten(mat[i])
if rhs is not None:
rhs[i] = vec_flatten(rhs[i])
if rhs is None:
return mat
else:
return mat, rhs
# }}}
gaussian_elimination = MovedFunctionDeprecationWrapper(reduced_row_echelon_form, "2025")
# {{{ symbolic (linear) equation solving
def solve_affine_equations_for(unknowns, equations):
......@@ -313,8 +438,8 @@ def solve_affine_equations_for(unknowns, equations):
from pymbolic import var
unknowns = [var(u) for u in unknowns]
unknowns_set = set(unknowns)
unknown_idx_lut = dict((tgt_name, idx)
for idx, tgt_name in enumerate(unknowns))
unknown_idx_lut = {tgt_name: idx
for idx, tgt_name in enumerate(unknowns)}
# Find non-unknown variables, fix order for them
# Last non-unknown is constant.
......@@ -324,8 +449,8 @@ def solve_affine_equations_for(unknowns, equations):
parameters.update(dep_map(rhs) - unknowns_set)
parameters_list = list(parameters)
parameter_idx_lut = dict((var_name, idx)
for idx, var_name in enumerate(parameters_list))
parameter_idx_lut = {var_name: idx
for idx, var_name in enumerate(parameters_list)}
from pymbolic.mapper.coefficient import CoefficientCollector
coeff_coll = CoefficientCollector()
......@@ -337,7 +462,7 @@ def solve_affine_equations_for(unknowns, equations):
for i_eqn, (lhs, rhs) in enumerate(equations):
for lhs_factor, coeffs in [(1, coeff_coll(lhs)), (-1, coeff_coll(rhs))]:
for key, coeff in six.iteritems(coeffs):
for key, coeff in coeffs.items():
if key in unknowns_set:
mat[i_eqn, unknown_idx_lut[key]] = lhs_factor*coeff
elif key in parameters:
......@@ -345,11 +470,11 @@ def solve_affine_equations_for(unknowns, equations):
elif key == 1:
rhs_mat[i_eqn, -1] = -lhs_factor*coeff
else:
raise ValueError("key '%s' not understood" % key)
raise ValueError(f"key '{key}' not understood")
# }}}
mat, rhs_mat = gaussian_elimination(mat, rhs_mat)
mat, rhs_mat = reduced_row_echelon_form(mat, rhs_mat, integral=True)
# FIXME /!\ Does not check for overdetermined system.
......@@ -357,27 +482,28 @@ def solve_affine_equations_for(unknowns, equations):
for j, unknown in enumerate(unknowns):
(nonz_row,) = np.where(mat[:, j])
if len(nonz_row) != 1:
raise RuntimeError("cannot uniquely solve for '%s'" % unknown)
raise RuntimeError(f"cannot uniquely solve for '{unknown}'")
(nonz_row,) = nonz_row
if abs(mat[nonz_row, j]) != 1:
raise RuntimeError("division with remainder in linear solve for '%s'"
% unknown)
raise RuntimeError(
f"division with remainder in linear solve for '{unknown}'")
div = mat[nonz_row, j]
unknown_val = int(rhs_mat[nonz_row, -1]) // div
for parameter, coeff in zip(parameters_list, rhs_mat[nonz_row]):
for parameter, coeff in zip(
parameters_list, rhs_mat[nonz_row, :-1], strict=True):
unknown_val += (int(coeff) // div) * parameter
result[unknown] = unknown_val
if 0:
for lhs, rhs in equations:
print(lhs, '=', rhs)
print(lhs, "=", rhs)
print("-------------------")
for lhs, rhs in six.iteritems(result):
print(lhs, '=', rhs)
for lhs, rhs in result.items():
print(lhs, "=", rhs)
return result
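# Hedged example: solve x + y = 7 + a and x - y = 3 - a for x and y, keeping
# 'a' as a free parameter. The result is a dict keyed by the unknown Variables;
# the expected values (up to term ordering) are x = 5 and y = 2 + a:
#
#   from pymbolic import parse
#   from pymbolic.algorithm import solve_affine_equations_for
#   solve_affine_equations_for(
#       ["x", "y"],
#       [(parse("x + y"), parse("7 + a")), (parse("x - y"), parse("3 - a"))])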
......
from __future__ import division
from __future__ import absolute_import
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
......@@ -24,59 +24,26 @@ THE SOFTWARE.
"""
import math
import pymbolic
from pymbolic.mapper.stringifier import (StringifyMapper, PREC_NONE,
PREC_SUM, PREC_POWER)
def _constant_mapper(c):
# work around numpy bug #1137 (locale-sensitive repr)
# http://projects.scipy.org/numpy/ticket/1137
try:
import numpy
except ImportError:
pass
else:
if isinstance(c, numpy.floating):
c = float(c)
elif isinstance(c, numpy.complexfloating):
c = complex(c)
return repr(c)
import pymbolic
from pymbolic.mapper.stringifier import PREC_NONE, StringifyMapper
class CompileMapper(StringifyMapper):
def __init__(self):
StringifyMapper.__init__(self,
constant_mapper=_constant_mapper)
def map_polynomial(self, expr, enclosing_prec):
# Use Horner's scheme to evaluate the polynomial
sbase = self(expr.base, PREC_POWER)
def stringify_exp(exp):
if exp == 0:
return ""
elif exp == 1:
return "*%s" % sbase
else:
return "*%s**%s" % (sbase, exp)
result = ""
rev_data = expr.data[::-1]
for i, (exp, coeff) in enumerate(rev_data):
if i+1 < len(rev_data):
next_exp = rev_data[i+1][0]
else:
next_exp = 0
result = "(%s+%s)%s" % (result, self(coeff, PREC_SUM),
stringify_exp(exp-next_exp))
if enclosing_prec > PREC_SUM and len(expr.data) > 1:
return "(%s)" % result
def map_constant(self, expr, enclosing_prec):
# work around numpy bug #1137 (locale-sensitive repr)
# https://github.com/numpy/numpy/issues/1735
try:
import numpy
except ImportError:
pass
else:
return result
if isinstance(expr, numpy.floating):
expr = float(expr)
elif isinstance(expr, numpy.complexfloating):
expr = complex(expr)
return repr(expr)
def map_numpy_array(self, expr, enclosing_prec):
def stringify_leading_dimension(ary):
......@@ -86,28 +53,30 @@ class CompileMapper(StringifyMapper):
else:
rec = stringify_leading_dimension
return "[%s]" % (", ".join(rec(x) for x in ary))
return "[{}]".format(", ".join(rec(x) for x in ary))
return "numpy.array(%s)" % stringify_leading_dimension(expr)
return "numpy.array({})".format(stringify_leading_dimension(expr))
def map_foreign(self, expr, enclosing_prec):
return StringifyMapper.map_foreign(self, expr, enclosing_prec)
class CompiledExpression(object):
class CompiledExpression:
"""This class encapsulates an expression compiled into Python bytecode
for faster evaluation.
Its instances (unlike plain lambdas) are pickleable.
"""
def __init__(self, expression, variables=[]):
def __init__(self, expression, variables=None):
"""
:arg variables: The first arguments (as strings or
:class:`pymbolic.primitives.Variable` instances) to be used for the
compiled function. All variables used by the expression and not
present here are added in alphabetical order.
present here are added in lexicographic order.
"""
if variables is None:
variables = []
self._compile(expression, variables)
def _compile(self, expression, variables):
......@@ -127,13 +96,13 @@ class CompiledExpression(object):
used_variables = DependencyMapper(
composite_leaves=False)(self._Expression)
used_variables -= set(self._Variables)
used_variables -= set(pymbolic.var(key) for key in list(ctx.keys()))
used_variables -= {pymbolic.var(key) for key in list(ctx.keys())}
used_variables = list(used_variables)
used_variables.sort()
all_variables = self._Variables + used_variables
expr_s = CompileMapper()(self._Expression, PREC_NONE)
func_s = "lambda %s: %s" % (",".join(str(v) for v in all_variables),
func_s = "lambda {}: {}".format(",".join(str(v) for v in all_variables),
expr_s)
self._code = eval(func_s, ctx)
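# Hedged usage sketch: pymbolic.compile is assumed to be an alias for
# CompiledExpression, and instances are assumed callable with the variables'
# values in the order given at compile time:
#
#   from pymbolic import parse, compile
#   f = compile(parse("x**2 + y"), ["x", "y"])
#   f(3, 2)  # -> 11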
......
from __future__ import division
from __future__ import absolute_import
import six
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
......@@ -27,25 +26,22 @@ THE SOFTWARE.
import pymbolic.primitives as prim
from pymbolic.mapper import IdentityMapper, WalkMapper
COMMUTATIVE_CLASSES = (prim.Sum, prim.Product)
COMMUTATIVE_CLASSES = (prim.Sum, prim.Product)
class NormalizedKeyGetter(object):
class NormalizedKeyGetter:
def __call__(self, expr):
if isinstance(expr, COMMUTATIVE_CLASSES):
kid_count = {}
for child in expr.children:
kid_count[child] = kid_count.get(child, 0) + 1
return type(expr), frozenset(six.iteritems(kid_count))
return type(expr), frozenset(kid_count.items())
else:
return expr
class UseCountMapper(WalkMapper):
def __init__(self, get_key):
self.subexpr_counts = {}
......@@ -80,9 +76,6 @@ class UseCountMapper(WalkMapper):
self.subexpr_counts[key] = 1
class CSEMapper(IdentityMapper):
def __init__(self, to_eliminate, get_key):
self.to_eliminate = to_eliminate
......@@ -97,8 +90,9 @@ class CSEMapper(IdentityMapper):
try:
return self.canonical_subexprs[key]
except KeyError:
new_expr = prim.wrap_in_cse(
getattr(IdentityMapper, expr.mapper_method)(self, expr))
new_expr = prim.make_common_subexpression(
getattr(IdentityMapper, expr.mapper_method)(self, expr)
)
self.canonical_subexprs[key] = new_expr
return new_expr
......@@ -120,39 +114,39 @@ class CSEMapper(IdentityMapper):
def map_common_subexpression(self, expr):
# Avoid creating CSE(CSE(...))
if type(expr) is prim.CommonSubexpression:
return prim.wrap_in_cse(self.rec(expr.child), expr.prefix)
return prim.make_common_subexpression(
self.rec(expr.child), expr.prefix, expr.scope
)
else:
# expr is of a derived CSE type
result = self.rec(expr.child)
if type(result) is prim.CommonSubexpression:
result = result.child
return type(expr)(result, expr.prefix, **expr.get_extra_properties())
return type(expr)(result, expr.prefix, expr.scope,
**expr.get_extra_properties())
def map_substitution(self, expr):
return type(expr)(
expr.child,
expr.variables,
tuple(self.rec(v) for v in expr.values))
tuple([self.rec(v) for v in expr.values]))
def tag_common_subexpressions(exprs):
get_key = NormalizedKeyGetter()
ucm = UseCountMapper(get_key)
if isinstance(exprs, prim.Expression):
if isinstance(exprs, prim.ExpressionNode):
raise TypeError("exprs should be an iterable of expressions")
for expr in exprs:
ucm(expr)
to_eliminate = set([subexpr_key
for subexpr_key, count in six.iteritems(ucm.subexpr_counts)
if count > 1])
to_eliminate = {subexpr_key
for subexpr_key, count in ucm.subexpr_counts.items()
if count > 1}
cse_mapper = CSEMapper(to_eliminate, get_key)
result = [cse_mapper(expr) for expr in exprs]
return result
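# Hedged sketch (the module path pymbolic.cse is assumed): subexpressions that
# occur more than once across the input list come back wrapped in a shared
# CommonSubexpression node:
#
#   from pymbolic import parse
#   from pymbolic.cse import tag_common_subexpressions
#   a, b = tag_common_subexpressions([parse("(x + y)*u"), parse("(x + y)*v")])
#   # both results now reference the same CSE node wrapping x + y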
from __future__ import division
from __future__ import absolute_import
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
......@@ -24,21 +24,48 @@ THE SOFTWARE.
"""
import pymbolic.primitives as p
import pymbolic.primitives as primitives
def sin(x):
return p.Call(p.Lookup(p.Variable("math"), "sin"), (x,))
def cos(x):
return p.Call(p.Lookup(p.Variable("math"), "cos"), (x,))
def sin(x):
return primitives.Call(primitives.Lookup(primitives.Variable("math"), "sin"), (x,))
def cos(x):
return primitives.Call(primitives.Lookup(primitives.Variable("math"), "cos"), (x,))
def tan(x):
return primitives.Call(primitives.Lookup(primitives.Variable("math"), "tan"), (x,))
return p.Call(p.Lookup(p.Variable("math"), "tan"), (x,))
def log(x):
return primitives.Call(primitives.Lookup(primitives.Variable("math"), "log"), (x,))
return p.Call(p.Lookup(p.Variable("math"), "log"), (x,))
def exp(x):
return primitives.Call(primitives.Lookup(primitives.Variable("math"), "exp"), (x,))
return p.Call(p.Lookup(p.Variable("math"), "exp"), (x,))
def sinh(x):
return p.Call(p.Lookup(p.Variable("math"), "sinh"), (x,))
def cosh(x):
return p.Call(p.Lookup(p.Variable("math"), "cosh"), (x,))
def tanh(x):
return p.Call(p.Lookup(p.Variable("math"), "tanh"), (x,))
def expm1(x):
return p.Call(p.Lookup(p.Variable("math"), "expm1"), (x,))
def fabs(x):
return p.Call(p.Lookup(p.Variable("math"), "fabs"), (x,))
def sign(x):
return p.Call(p.Lookup(p.Variable("math"), "copysign"), (1, x,))
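# Hedged sketch (the module path pymbolic.functions is assumed): these helpers
# only build Call(Lookup(Variable("math"), ...)) nodes, so evaluation needs the
# name "math" bound in the context:
#
#   import math
#   from pymbolic import var, evaluate
#   from pymbolic.functions import sin
#   evaluate(sin(var("t")) + 1, {"t": 0.0, "math": math})  # -> 1.0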
This diff is collapsed.
# -*- coding: utf-8 -*-
from __future__ import division
from __future__ import absolute_import
from six.moves import range
from six.moves import zip
from __future__ import annotations
__copyright__ = "Copyright (C) 2014 Andreas Kloeckner"
......@@ -28,66 +25,96 @@ THE SOFTWARE.
# This is experimental, undocumented, and could go away any second.
# Consider yourself warned.
from typing import TYPE_CHECKING, ClassVar
from pymbolic.geometric_algebra import MultiVector
import pymbolic.geometric_algebra.primitives as prim
from pymbolic.geometric_algebra import MultiVector
from pymbolic.mapper import (
CombineMapper as CombineMapperBase,
Collector as CollectorBase,
IdentityMapper as IdentityMapperBase,
WalkMapper as WalkMapperBase,
CachingMapperMixin
)
CachedMapper,
CollectedT,
Collector as CollectorBase,
CombineMapper as CombineMapperBase,
IdentityMapper as IdentityMapperBase,
P,
ResultT,
WalkMapper as WalkMapperBase,
)
from pymbolic.mapper.constant_folder import (
ConstantFoldingMapper as ConstantFoldingMapperBase)
from pymbolic.mapper.graphviz import (
GraphvizMapper as GraphvizMapperBase)
ConstantFoldingMapper as ConstantFoldingMapperBase,
)
from pymbolic.mapper.evaluator import EvaluationMapper as EvaluationMapperBase
from pymbolic.mapper.graphviz import GraphvizMapper as GraphvizMapperBase
from pymbolic.mapper.stringifier import (
StringifyMapper as StringifyMapperBase,
PREC_NONE
)
from pymbolic.mapper.evaluator import (
EvaluationMapper as EvaluationMapperBase)
PREC_NONE,
StringifyMapper as StringifyMapperBase,
)
class IdentityMapper(IdentityMapperBase):
def map_multivector_variable(self, expr):
if TYPE_CHECKING:
from collections.abc import Set
from pymbolic.primitives import ExpressionNode
class IdentityMapper(IdentityMapperBase[P]):
def map_nabla(
self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs
) -> ExpressionNode:
return expr
def map_nabla_component(self,
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
) -> ExpressionNode:
return expr
map_nabla = map_multivector_variable
map_nabla_component = map_multivector_variable
def map_derivative_source(self,
expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
) -> ExpressionNode:
operand = self.rec(expr.operand, *args, **kwargs)
if operand is expr.operand:
return expr
def map_derivative_source(self, expr):
return type(expr)(self.rec(expr.operand), expr.nabla_id)
return type(expr)(operand, expr.nabla_id)
class CombineMapper(CombineMapperBase):
def map_derivative_source(self, expr):
return self.rec(expr.operand)
class CombineMapper(CombineMapperBase[ResultT, P]):
def map_derivative_source(
self, expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
return self.rec(expr.operand, *args, **kwargs)
class Collector(CollectorBase):
def map_nabla(self, expr):
class Collector(CollectorBase[CollectedT, P]):
def map_nabla(self,
expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs
) -> Set[CollectedT]:
return set()
map_nabla_component = map_nabla
def map_nabla_component(self,
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
) -> Set[CollectedT]:
return set()
class WalkMapper(WalkMapperBase):
def map_nabla(self, expr, *args):
self.visit(expr, *args)
self.post_visit(expr)
class WalkMapper(WalkMapperBase[P]):
def map_nabla(self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> None:
self.visit(expr, *args, **kwargs)
self.post_visit(expr, *args, **kwargs)
def map_nabla_component(self, expr, *args):
self.visit(expr, *args)
self.post_visit(expr)
def map_nabla_component(
self, expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
) -> None:
self.visit(expr, *args, **kwargs)
self.post_visit(expr, *args, **kwargs)
def map_derivative_source(self, expr, *args):
if not self.visit(expr, *args):
def map_derivative_source(
self, expr, *args: P.args, **kwargs: P.kwargs
) -> None:
if not self.visit(expr, *args, **kwargs):
return
self.rec(expr.operand)
self.post_visit(expr)
self.rec(expr.operand, *args, **kwargs)
self.post_visit(expr, *args, **kwargs)
class EvaluationMapper(EvaluationMapperBase):
......@@ -97,38 +124,32 @@ class EvaluationMapper(EvaluationMapperBase):
map_nabla = map_nabla_component
def map_derivative_source(self, expr):
return type(expr)(self.rec(expr.operand), expr.nabla_id)
operand = self.rec(expr.operand)
if operand is expr.operand:
return expr
return type(expr)(operand, expr.nabla_id)
class StringifyMapper(StringifyMapperBase):
AXES = {0: "x", 1: "y", 2: "z"}
class StringifyMapper(StringifyMapperBase[[]]):
AXES: ClassVar[dict[int, str]] = {0: "x", 1: "y", 2: "z"}
def map_nabla(self, expr, enclosing_prec):
import sys
if sys.version_info >= (3,):
return u"∇[%s]" % expr.nabla_id
else:
return r"\/[%s]" % expr.nabla_id
return f"∇[{expr.nabla_id}]"
def map_nabla_component(self, expr, enclosing_prec):
import sys
if sys.version_info >= (3,):
return u"∇%s[%s]" % (
self.AXES.get(expr.ambient_axis, expr.ambient_axis),
expr.nabla_id)
else:
return r"\/%s[%s]" % (
self.AXES.get(expr.ambient_axis, expr.ambient_axis),
expr.nabla_id)
return "∇{}[{}]".format(
self.AXES.get(expr.ambient_axis, expr.ambient_axis),
expr.nabla_id)
def map_derivative_source(self, expr, enclosing_prec):
return r"D[%s](%s)" % (expr.nabla_id, self.rec(expr.operand, PREC_NONE))
return r"D[{}]({})".format(expr.nabla_id, self.rec(expr.operand, PREC_NONE))
class GraphvizMapper(GraphvizMapperBase):
def map_derivative_source(self, expr):
self.lines.append(
"%s [label=\"D[%s]\",shape=ellipse];" % (
'{} [label="D[{}]",shape=ellipse];'.format(
self.get_id(expr), expr.nabla_id))
if not self.visit(expr, node_printed=True):
return
......@@ -146,11 +167,15 @@ class Dimensionalizer(EvaluationMapper):
Dimension of ambient space. Must be provided by subclass.
"""
@property
def ambient_dim(self):
raise NotImplementedError
def map_multivector_variable(self, expr):
from pymbolic.primitives import make_sym_vector
return MultiVector(
make_sym_vector(expr.name, self.ambient_dim,
var_class=type(expr)))
var_factory=type(expr)))
def map_nabla(self, expr):
from pytools.obj_array import make_obj_array
......@@ -166,23 +191,27 @@ class Dimensionalizer(EvaluationMapper):
return rec_op.map(
lambda coeff: DerivativeSource(coeff, expr.nabla_id))
else:
return super(Dimensionalizer, self).map_derivative_source(expr)
return super().map_derivative_source(expr)
# }}}
# {{{ derivative binder
class DerivativeSourceAndNablaComponentCollector(CachingMapperMixin, Collector):
class DerivativeSourceAndNablaComponentCollector(CachedMapper, Collector):
def __init__(self) -> None:
Collector.__init__(self)
CachedMapper.__init__(self)
def map_nabla(self, expr):
raise RuntimeError("DerivativeOccurrenceMapper must be invoked after "
"Dimensionalizer--Nabla found, not allowed")
def map_nabla_component(self, expr):
return set([expr])
return {expr}
def map_derivative_source(self, expr):
return set([expr]) | self.rec(expr.operand)
return {expr} | self.rec(expr.operand)
class NablaComponentToUnitVector(EvaluationMapper):
......@@ -204,7 +233,8 @@ class NablaComponentToUnitVector(EvaluationMapper):
class DerivativeSourceFinder(EvaluationMapper):
"""Recurses down until it finds the :class:`pytential.sym.DerivativeSource`
"""Recurses down until it finds the
:class:`pymbolic.geometric_algebra.primitives.DerivativeSource`
with the right *nabla_id*, then calls :meth:`DerivativeBinder.take_derivative`
on the source's argument.
"""
......@@ -227,27 +257,34 @@ class DerivativeBinder(IdentityMapper):
nabla_component_to_unit_vector = NablaComponentToUnitVector
derivative_source_finder = DerivativeSourceFinder
def __init__(self):
def __init__(self, restrict_to_id=None):
self.derivative_collector = \
self.derivative_source_and_nabla_component_collector()
self.restrict_to_id = restrict_to_id
def map_product(self, expr):
# We may write to this below. Make a copy.
children = list(expr.children)
def take_derivative(self, ambient_axis, expr):
raise NotImplementedError
def map_product(self, expr):
# {{{ gather NablaComponents and DerivativeSources
d_source_nabla_ids_per_child = []
# id to set((child index, axis), ...)
nabla_finder = {}
has_d_source_nablas = False
for child_idx, child in enumerate(expr.children):
d_or_ns = self.derivative_collector(child)
if not d_or_ns:
d_source_nabla_ids_per_child.append(set())
continue
for child_idx, rec_child in enumerate(children):
nabla_component_ids = set()
derivative_source_ids = set()
nablas = []
for d_or_n in self.derivative_collector(rec_child):
for d_or_n in d_or_ns:
if isinstance(d_or_n, prim.NablaComponent):
nabla_component_ids.add(d_or_n.nabla_id)
nablas.append(d_or_n)
......@@ -257,20 +294,35 @@ class DerivativeBinder(IdentityMapper):
raise RuntimeError("unexpected result from "
"DerivativeSourceAndNablaComponentCollector")
d_source_nabla_ids_per_child.append(
derivative_source_ids - nabla_component_ids)
d_source_nabla_ids_per_child.append(derivative_source_ids)
if derivative_source_ids:
has_d_source_nablas = True
for ncomp in nablas:
nabla_finder.setdefault(
ncomp.nabla_id, set()).add((child_idx, ncomp.ambient_axis))
if nabla_finder and not any(d_source_nabla_ids_per_child):
raise ValueError(f"no derivative source found to resolve in '{expr}'"
" -- did you forget to wrap the term that should have its "
"derivative taken in 'Derivative()(term)'?")
if not has_d_source_nablas:
rec_children = [self.rec(child) for child in expr.children]
if all(rec_child is child
for rec_child, child in zip(
rec_children, expr.children, strict=True)):
return expr
return type(expr)(tuple(rec_children))
# }}}
# a list of lists, the outer level presenting a sum, the inner a product
result = [children]
result = [list(expr.children)]
for child_idx, (d_source_nabla_ids, child) in enumerate(
zip(d_source_nabla_ids_per_child, children)):
for child_idx, (d_source_nabla_ids, _child) in enumerate(
zip(d_source_nabla_ids_per_child, expr.children, strict=True)):
if not d_source_nabla_ids:
continue
......@@ -284,6 +336,9 @@ class DerivativeBinder(IdentityMapper):
except KeyError:
continue
if self.restrict_to_id is not None and nabla_id != self.restrict_to_id:
continue
n_axes = max(axis for _, axis in nablas) + 1
new_result = []
......@@ -302,11 +357,13 @@ class DerivativeBinder(IdentityMapper):
result = new_result
from pymbolic.primitives import flattened_sum
return flattened_sum(
type(expr)(tuple(
self.rec(prod_term) for prod_term in prod_term_list))
for prod_term_list in result)
from pymbolic.primitives import flattened_product, flattened_sum
return flattened_sum([
flattened_product([
self.rec(prod_term) for prod_term in prod_term_list
])
for prod_term_list in result
])
map_bitwise_xor = map_product
map_bitwise_or = map_product
......@@ -330,12 +387,13 @@ class DerivativeBinder(IdentityMapper):
assert n_axes
from pymbolic.primitives import flattened_sum
return flattened_sum(
return flattened_sum([
self.take_derivative(
axis,
self.nabla_component_to_unit_vector(expr.nabla_id, axis)
(rec_operand))
for axis in range(n_axes))
for axis in range(n_axes)
])
# }}}
......
from __future__ import division
from __future__ import absolute_import
from __future__ import annotations
__copyright__ = "Copyright (C) 2014 Andreas Kloeckner"
......@@ -26,7 +26,15 @@ THE SOFTWARE.
# This is experimental, undocumented, and could go away any second.
# Consider yourself warned.
from pymbolic.primitives import Expression, Variable
from typing import TYPE_CHECKING, ClassVar
from pymbolic.primitives import ExpressionNode, Variable, expr_dataclass
if TYPE_CHECKING:
from collections.abc import Hashable
from pymbolic.typing import Expression
class MultiVectorVariable(Variable):
......@@ -35,61 +43,55 @@ class MultiVectorVariable(Variable):
# {{{ geometric calculus
class _GeometricCalculusExpression(Expression):
class _GeometricCalculusExpression(ExpressionNode):
def stringifier(self):
from pymbolic.geometric_algebra.mapper import StringifyMapper
return StringifyMapper
@expr_dataclass()
class NablaComponent(_GeometricCalculusExpression):
def __init__(self, ambient_axis, nabla_id):
self.ambient_axis = ambient_axis
self.nabla_id = nabla_id
def __getinitargs__(self):
return (self.ambient_axis, self.nabla_id)
mapper_method = "map_nabla_component"
ambient_axis: int
nabla_id: Hashable
@expr_dataclass()
class Nabla(_GeometricCalculusExpression):
def __init__(self, nabla_id):
self.nabla_id = nabla_id
def __getinitargs__(self):
return (self.nabla_id,)
def __getitem__(self, index):
if not isinstance(index, int):
raise TypeError("Nabla subscript must be an integer")
return NablaComponent(index, self.nabla_id)
mapper_method = "map_nabla"
nabla_id: Hashable
@expr_dataclass()
class DerivativeSource(_GeometricCalculusExpression):
def __init__(self, operand, nabla_id=None):
self.operand = operand
self.nabla_id = nabla_id
operand: Expression
nabla_id: Hashable
def __getinitargs__(self):
return (self.operand, self.nabla_id)
mapper_method = "map_derivative_source"
class Derivative:
"""This mechanism cannot be used to take more than one derivative at a time.
class Derivative(object):
_next_id = [0]
.. autoproperty:: nabla
.. automethod:: __call__
.. automethod:: dnabla
.. automethod:: resolve
"""
_next_id: ClassVar[list[int]] = [0]
def __init__(self):
self.my_id = "id%s" % self._next_id[0]
self.my_id = f"id{self._next_id[0]}"
self._next_id[0] += 1
@property
def nabla(self):
return Nabla(self.my_id)
def dnabla(self, ambient_dim):
from pytools.obj_array import make_obj_array
from pymbolic.geometric_algebra import MultiVector
return MultiVector(make_obj_array(
[NablaComponent(axis, self.my_id)
for axis in range(ambient_dim)]))
def __call__(self, operand):
from pymbolic.geometric_algebra import MultiVector
if isinstance(operand, MultiVector):
......@@ -98,6 +100,14 @@ class Derivative(object):
else:
return DerivativeSource(operand, self.my_id)
@staticmethod
def resolve(expr):
# This method will need to be overridden by codes using this
# infrastructure to use the appropriate subclass of DerivativeBinder.
from pymbolic.geometric_algebra.mapper import DerivativeBinder
return DerivativeBinder()(expr)
# }}}
# vim: foldmethod=marker
"""Imperative program representation"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Matt Wala, Andreas Kloeckner"
......
"""Fusion and other user-facing code transforms"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Matt Wala, Andreas Kloeckner"
......
"""Instruction types"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Matt Wala, Andreas Kloeckner"
......@@ -22,195 +24,16 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from warnings import warn
import six
from pytools import RecordWithoutPickling
# {{{ instruction classes
class Instruction(RecordWithoutPickling):
"""
.. attribute:: depends_on
A :class:`frozenset` of instruction ids that are required to be
executed within this execution context before this instruction can be
executed.
.. attribute:: id
A string, a unique identifier for this instruction.
.. automethod:: get_written_variables
.. automethod:: get_read_variables
"""
def __init__(self, **kwargs):
id = kwargs.pop("id", None)
if id is not None:
id = six.moves.intern(id)
depends_on = frozenset(kwargs.pop("depends_on", []))
super(Instruction, self).__init__(
id=id,
depends_on=depends_on,
**kwargs)
def get_written_variables(self):
"""Returns a :class:`frozenset` of variables being written by this
instruction.
"""
return frozenset()
def get_read_variables(self):
"""Returns a :class:`frozenset` of variables being read by this
instruction.
"""
return frozenset()
def map_expressions(self, mapper, include_lhs=True):
"""Returns a new copy of *self* with all expressions
replaced by ``mapper(expr)`` for every
:class:`pymbolic.primitives.Expression`
contained in *self*.
"""
return self
def get_dependency_mapper(self, include_calls="descend_args"):
from pymbolic.mapper.dependency import DependencyMapper
return DependencyMapper(
include_subscripts=False,
include_lookups=False,
include_calls=include_calls)
# }}}
# {{{ instruction with condition
class ConditionalInstruction(Instruction):
__doc__ = Instruction.__doc__ + """
.. attribute:: condition
The instruction condition as a :mod:`pymbolic` expression (`True` if the
instruction is unconditionally executed)
"""
def __init__(self, **kwargs):
condition = kwargs.pop("condition", True)
super(ConditionalInstruction, self).__init__(
condition=condition,
**kwargs)
def _condition_printing_suffix(self):
if self.condition is True:
return ""
return " if " + str(self.condition)
def __str__(self):
return (super(ConditionalInstruction, self).__str__()
+ self._condition_printing_suffix())
def get_read_variables(self):
dep_mapper = self.get_dependency_mapper()
return (
super(ConditionalInstruction, self).get_read_variables()
|
frozenset(
dep.name for dep in dep_mapper(self.condition)))
# }}}
# {{{ assignment
class Assignment(Instruction):
"""
.. attribute:: lhs
.. attribute:: rhs
"""
def __init__(self, lhs, rhs, **kwargs):
super(Assignment, self).__init__(
lhs=lhs,
rhs=rhs,
**kwargs)
def get_written_variables(self):
from pymbolic.primitives import Variable, Subscript
if isinstance(self.lhs, Variable):
return frozenset([self.lhs.name])
elif isinstance(self.lhs, Subscript):
assert isinstance(self.lhs.aggregate, Variable)
return frozenset([self.lhs.aggregate.name])
else:
raise TypeError("unexpected type of LHS")
def get_read_variables(self):
result = super(Assignment, self).get_read_variables()
get_deps = self.get_dependency_mapper()
def get_vars(expr):
return frozenset(dep.name for dep in get_deps(expr))
result = get_vars(self.rhs) | get_vars(self.lhs)
return result
def map_expressions(self, mapper, include_lhs=True):
return (super(Assignment, self)
.map_expressions(mapper, include_lhs=include_lhs)
.copy(
lhs=mapper(self.lhs) if include_lhs else self.lhs,
rhs=mapper(self.rhs)))
def __str__(self):
result = "{assignee} <- {expr}".format(
assignee=str(self.lhs),
expr=str(self.rhs),)
return result
# }}}
# {{{ conditional assignment
class ConditionalAssignment(ConditionalInstruction, Assignment):
"""
.. attribute:: assignee
.. attribute:: assignee_subscript
The subscript in :attr:`assignee` which is being assigned.
A tuple, which may be empty, to indicate 'no subscript'.
.. attribute:: expression
.. attribute:: loops
A list of triples *(identifier, start, end)* specifying the loops
inside which the assignment should be carried out.
No ordering of loop iterations is implied.
The loops will typically be nested outer-to-inner, but a target
may validly use any order it likes.
"""
def map_expressions(self, mapper, include_lhs=True):
return (super(ConditionalAssignment, self)
.map_expressions(mapper, include_lhs=include_lhs)
.copy(condition=mapper(self.condition)))
# }}}
# {{{ nop
class Nop(Instruction):
def __str__(self):
return 'nop'
# }}}
warn("pymbolic.imperative.instruction was imported. This has been renamed "
"to pymbolic.imperative.statement", DeprecationWarning, stacklevel=1)
# vim: foldmethod=marker
from pymbolic.imperative.statement import ( # noqa: F401
Assignment,
ConditionalAssignment,
ConditionalInstruction,
Instruction,
Nop,
)
"""Instruction types"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Matt Wala, Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from sys import intern
from pytools import RecordWithoutPickling
from pymbolic.typing import not_none
# {{{ statement classes
class Statement(RecordWithoutPickling):
"""
.. attribute:: depends_on
A :class:`frozenset` of instruction ids that are required to be
executed within this execution context before this instruction can be
executed.
.. attribute:: id
A string, a unique identifier for this instruction.
.. automethod:: get_written_variables
.. automethod:: get_read_variables
"""
def __init__(self, **kwargs):
id = kwargs.pop("id", None)
if id is not None:
id = intern(id)
depends_on = frozenset(kwargs.pop("depends_on", []))
super().__init__(
id=id,
depends_on=depends_on,
**kwargs)
def get_written_variables(self):
"""Returns a :class:`frozenset` of variables being written by this
instruction.
"""
return frozenset()
def get_read_variables(self):
"""Returns a :class:`frozenset` of variables being read by this
instruction.
"""
return frozenset()
def map_expressions(self, mapper, include_lhs=True):
"""Returns a new copy of *self* with all expressions
replaced by ``mapper(expr)`` for every
:class:`pymbolic.primitives.Expression`
contained in *self*.
"""
return self
def get_dependency_mapper(self, include_calls="descend_args"):
from pymbolic.mapper.dependency import DependencyMapper
return DependencyMapper(
include_subscripts=False,
include_lookups=False,
include_calls=include_calls)
# }}}
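# A minimal sketch (not part of the module source) of how the Statement base
# class above is meant to be used: subclass it, give statements ids, and link
# them via ``depends_on``. The subclass and ids below are invented for the
# example.
#
#     class Barrier(Statement):
#         def __str__(self):
#             return "barrier"
#
#     b = Barrier(id="b0", depends_on=["a0"])
#     b.id                     # 'b0'
#     b.depends_on             # frozenset({'a0'})
#     b.get_read_variables()   # frozenset() -- the base-class default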
# {{{ statement with condition
class ConditionalStatement(Statement):
__doc__ = not_none(Statement.__doc__) + """
.. attribute:: condition
The instruction condition as a :mod:`pymbolic` expression (`True` if the
instruction is unconditionally executed)
"""
def __init__(self, **kwargs):
condition = kwargs.pop("condition", True)
super().__init__(
condition=condition,
**kwargs)
def _condition_printing_suffix(self):
if self.condition is True:
return ""
return " if " + str(self.condition)
def __str__(self):
return (super().__str__()
+ self._condition_printing_suffix())
def get_read_variables(self):
dep_mapper = self.get_dependency_mapper()
return (
super().get_read_variables()
| frozenset(
dep.name for dep in dep_mapper(self.condition)))
# }}}
# {{{ assignment
class Assignment(Statement):
"""
.. attribute:: lhs
.. attribute:: rhs
"""
def __init__(self, lhs, rhs, **kwargs):
super().__init__(
lhs=lhs,
rhs=rhs,
**kwargs)
def get_written_variables(self):
from pymbolic.primitives import Subscript, Variable
if isinstance(self.lhs, Variable):
return frozenset([self.lhs.name])
elif isinstance(self.lhs, Subscript):
assert isinstance(self.lhs.aggregate, Variable)
return frozenset([self.lhs.aggregate.name])
else:
raise TypeError("unexpected type of LHS")
def get_read_variables(self):
result = super().get_read_variables()
get_deps = self.get_dependency_mapper()
def get_vars(expr):
return frozenset(dep.name for dep in get_deps(expr))
return result | get_vars(self.rhs) | get_vars(self.lhs)
def map_expressions(self, mapper, include_lhs=True):
return (super()
.map_expressions(mapper, include_lhs=include_lhs)
.copy(
lhs=mapper(self.lhs) if include_lhs else self.lhs,
rhs=mapper(self.rhs)))
def __str__(self):
result = "{assignee} <- {expr}".format(
assignee=str(self.lhs),
expr=str(self.rhs),)
return result
# }}}
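# A minimal usage sketch (not part of the module source) for the Assignment
# class above, assuming the pymbolic parser; the variable names are invented.
#
#     from pymbolic import parse
#     stmt = Assignment(lhs=parse("x"), rhs=parse("y + 2*z"), id="compute_x")
#     print(stmt)                        # x <- y + 2*z
#     stmt.get_written_variables()       # frozenset({'x'})
#     "y" in stmt.get_read_variables()   # True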
# {{{ conditional assignment
class ConditionalAssignment(ConditionalStatement, Assignment):
def map_expressions(self, mapper, include_lhs=True):
return (super()
.map_expressions(mapper, include_lhs=include_lhs)
.copy(condition=mapper(self.condition)))
# }}}
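# A minimal sketch (not part of the module source) showing how the condition
# participates in stringification and in the read-variable set; the names are
# invented for the example.
#
#     from pymbolic import parse
#     stmt = ConditionalAssignment(lhs=parse("x"), rhs=parse("y"),
#                                  condition=parse("flag > 0"), id="maybe_x")
#     print(stmt)                           # x <- y if flag > 0
#     "flag" in stmt.get_read_variables()   # True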
# {{{ nop
class Nop(Statement):
def __str__(self):
return "nop"
# }}}
Instruction = Statement
ConditionalInstruction = ConditionalStatement
# vim: foldmethod=marker
"""Imperative program representation: transformations"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Matt Wala, Andreas Kloeckner"
@@ -23,48 +25,57 @@ THE SOFTWARE.
"""
# {{{ fuse instruction streams
# {{{ fuse statement streams
def fuse_instruction_streams_with_unique_ids(instructions_a, instructions_b):
new_instructions = list(instructions_a)
def fuse_statement_streams_with_unique_ids(statements_a, statements_b):
new_statements = list(statements_a)
from pytools import UniqueNameGenerator
insn_id_gen = UniqueNameGenerator(
set([insna.id for insna in new_instructions]))
stmt_id_gen = UniqueNameGenerator(
{stmta.id for stmta in new_statements})
b_unique_instructions = []
b_unique_statements = []
old_b_id_to_new_b_id = {}
for insnb in instructions_b:
old_id = insnb.id
new_id = insn_id_gen(old_id)
for stmtb in statements_b:
old_id = stmtb.id
new_id = stmt_id_gen(old_id)
old_b_id_to_new_b_id[old_id] = new_id
b_unique_instructions.append(
insnb.copy(id=new_id))
b_unique_statements.append(
stmtb.copy(id=new_id))
for insnb in b_unique_instructions:
new_instructions.append(
insnb.copy(
for stmtb in b_unique_statements:
new_statements.append(
stmtb.copy(
depends_on=frozenset(
old_b_id_to_new_b_id[dep_id]
for dep_id in insnb.depends_on)))
for dep_id in stmtb.depends_on)))
return new_statements, old_b_id_to_new_b_id
def fuse_instruction_streams_with_unique_ids(insns_a, insns_b):
from warnings import warn
warn("fuse_instruction_streams_with_unique_ids has been renamed to "
"fuse_statement_streams_with_unique_ids", DeprecationWarning,
stacklevel=2)
return new_instructions, old_b_id_to_new_b_id
return fuse_statement_streams_with_unique_ids(insns_a, insns_b)
# }}}
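# A minimal sketch (not part of the module source) of fusing two statement
# streams whose ids clash. Nop is used as a stand-in statement; the ids are
# invented for the example.
#
#     from pymbolic.imperative.statement import Nop
#     stream_a = [Nop(id="setup"), Nop(id="run", depends_on=["setup"])]
#     stream_b = [Nop(id="setup"), Nop(id="teardown", depends_on=["setup"])]
#     fused, id_map = fuse_statement_streams_with_unique_ids(stream_a, stream_b)
#     # stream_b's "setup" is renamed (e.g. to "setup_0") to avoid the clash,
#     # and "teardown"'s depends_on is rewritten to the renamed id.
#     id_map["setup"] != "setup"   # True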
# {{{ disambiguate_identifiers
def disambiguate_identifiers(instructions_a, instructions_b,
def disambiguate_identifiers(statements_a, statements_b,
should_disambiguate_name=None):
if should_disambiguate_name is None:
def should_disambiguate_name(name):
def should_disambiguate_name(name): # pylint:disable=function-redefined
return True
from pymbolic.imperative.analysis import get_all_used_identifiers
id_a = get_all_used_identifiers(instructions_a)
id_b = get_all_used_identifiers(instructions_b)
id_a = get_all_used_identifiers(statements_a)
id_b = get_all_used_identifiers(statements_b)
from pytools import UniqueNameGenerator
vng = UniqueNameGenerator(id_a | id_b)
@@ -76,29 +87,28 @@ def disambiguate_identifiers(instructions_a, instructions_b,
unclash = vng(clash)
subst_b[clash] = var(unclash)
from pymbolic.mapper.substitutor import (
make_subst_func, SubstitutionMapper)
from pymbolic.mapper.substitutor import SubstitutionMapper, make_subst_func
subst_map = SubstitutionMapper(make_subst_func(subst_b))
instructions_b = [
insn.map_expressions(subst_map) for insn in instructions_b]
statements_b = [
stmt.map_expressions(subst_map) for stmt in statements_b]
return instructions_b, subst_b
return statements_b, subst_b
# }}}
# {{{ disambiguate_and_fuse
def disambiguate_and_fuse(instructions_a, instructions_b,
def disambiguate_and_fuse(statements_a, statements_b,
should_disambiguate_name=None):
instructions_b, subst_b = disambiguate_identifiers(
instructions_a, instructions_b,
statements_b, subst_b = disambiguate_identifiers(
statements_a, statements_b,
should_disambiguate_name)
fused, old_b_id_to_new_b_id = \
fuse_instruction_streams_with_unique_ids(
instructions_a, instructions_b)
fuse_statement_streams_with_unique_ids(
statements_a, statements_b)
return fused, subst_b, old_b_id_to_new_b_id
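# A minimal sketch (not part of the module source) of disambiguating and
# fusing two streams that write the same temporary; the Assignment statements
# and variable names are invented for the example.
#
#     from pymbolic import parse
#     from pymbolic.imperative.statement import Assignment
#     a = [Assignment(parse("tmp"), parse("x + 1"), id="s0")]
#     b = [Assignment(parse("tmp"), parse("y + 1"), id="s0")]
#     fused, subst_b, id_map = disambiguate_and_fuse(a, b)
#     # b's "tmp" is renamed (e.g. to "tmp_0") via subst_b, and b's "s0" is
#     # given a fresh id recorded in id_map.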
from __future__ import division, with_statement
from __future__ import annotations
__copyright__ = """
Copyright (C) 2013 Andreas Kloeckner
@@ -25,49 +26,82 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import six
import logging
logger = logging.getLogger(__name__)
# {{{ graphviz / dot export
def get_dot_dependency_graph(instructions, use_insn_ids=False,
addtional_lines_hook=None):
"""Return a string in the `dot <http://graphviz.org/>`_ language depicting
dependencies among kernel instructions.
def _default_preamble_hook():
# Sets default attributes for nodes and edges.
yield 'node [shape="box"];'
yield 'edge [dir="back"];'
def get_dot_dependency_graph(
statements, use_stmt_ids=None,
preamble_hook=_default_preamble_hook,
additional_lines_hook=list,
statement_stringifier=None,
# deprecated
use_insn_ids=None,):
"""Return a string in the `dot <https://graphviz.org/>`__ language depicting
dependencies among kernel statements.
:arg statements: A sequence of statements, each of which is stringified by
calling *statement_stringifier*.
:arg statement_stringifier: The function to use for stringifying the
statements. The default stringifier uses :class:`str` and escapes all
double quotes (``"``) in the string representation.
:arg preamble_hook: A function that returns an iterable of lines
to add at the beginning of the graph
:arg additional_lines_hook: A function that returns an iterable
of lines to add at the end of the graph
"""
if statement_stringifier is None:
def statement_stringifier(s):
return str(s).replace('"', r'\"')
if use_stmt_ids is not None and use_insn_ids is not None:
raise TypeError("may not specify both use_stmt_ids and use_insn_ids")
if use_insn_ids is not None:
use_stmt_ids = use_insn_ids
from warnings import warn
warn("'use_insn_ids' is deprecated. Use 'use_stmt_ids' instead.",
DeprecationWarning, stacklevel=2)
def get_node_attrs(stmt):
if use_stmt_ids:
stmt_label = stmt.id
tooltip = statement_stringifier(stmt)
else:
stmt_label = statement_stringifier(stmt)
tooltip = stmt.id
return f'label="{stmt_label}",shape="box",tooltip="{tooltip}"'
lines = []
lines = list(preamble_hook())
lines.append("rankdir=BT;")
dep_graph = {}
# maps (oriented) edge onto annotation string
annotation_dep_graph = {}
for insn in instructions:
if use_insn_ids:
insn_label = insn.id
tooltip = str(insn)
else:
insn_label = str(insn)
tooltip = insn.id
lines.append("\"%s\" [label=\"%s\",shape=\"box\",tooltip=\"%s\"];"
% (
insn.id,
repr(insn_label)[1:-1],
repr(tooltip)[1:-1],
))
for dep in insn.depends_on:
dep_graph.setdefault(insn.id, set()).add(dep)
for stmt in statements:
lines.append('"{}" [{}];'.format(stmt.id, get_node_attrs(stmt)))
for dep in stmt.depends_on:
dep_graph.setdefault(stmt.id, set()).add(dep)
if 0:
for dep in insn.then_depends_on:
annotation_dep_graph[(insn.id, dep)] = "then"
for dep in insn.else_depends_on:
annotation_dep_graph[(insn.id, dep)] = "else"
for dep in stmt.then_depends_on:
annotation_dep_graph[stmt.id, dep] = "then"
for dep in stmt.else_depends_on:
annotation_dep_graph[stmt.id, dep] = "else"
# {{{ O(n^3) (i.e. slow) transitive reduction
@@ -75,71 +109,49 @@ def get_dot_dependency_graph(instructions, use_insn_ids=False,
while True:
changed_something = False
for insn_1 in dep_graph:
for insn_2 in dep_graph.get(insn_1, set()).copy():
for insn_3 in dep_graph.get(insn_2, set()).copy():
if insn_3 not in dep_graph.get(insn_1, set()):
for stmt_1 in dep_graph:
for stmt_2 in dep_graph.get(stmt_1, set()).copy():
for stmt_3 in dep_graph.get(stmt_2, set()).copy():
if stmt_3 not in dep_graph.get(stmt_1, set()):
changed_something = True
dep_graph[insn_1].add(insn_3)
dep_graph[stmt_1].add(stmt_3)
if not changed_something:
break
for insn_1 in dep_graph:
for insn_2 in dep_graph.get(insn_1, set()).copy():
for insn_3 in dep_graph.get(insn_2, set()).copy():
if insn_3 in dep_graph.get(insn_1, set()):
dep_graph[insn_1].remove(insn_3)
for stmt_1 in dep_graph:
for stmt_2 in dep_graph.get(stmt_1, set()).copy():
for stmt_3 in dep_graph.get(stmt_2, set()).copy():
if stmt_3 in dep_graph.get(stmt_1, set()):
dep_graph[stmt_1].remove(stmt_3)
# }}}
for insn_1 in dep_graph:
for insn_2 in dep_graph.get(insn_1, set()):
lines.append("%s -> %s" % (insn_2, insn_1))
for stmt_1 in dep_graph:
for stmt_2 in dep_graph.get(stmt_1, set()):
lines.append(f"{stmt_1} -> {stmt_2}")
for (insn_1, insn_2), annot in six.iteritems(annotation_dep_graph):
lines.append(
"%s -> %s [label=\"%s\", style=dashed]"
% (insn_2, insn_1, annot))
for (stmt_1, stmt_2), annot in annotation_dep_graph.items():
lines.append(f'{stmt_2} -> {stmt_1} [label="{annot}", style="dashed"]')
if addtional_lines_hook is not None:
lines.extend(addtional_lines_hook())
lines.extend(additional_lines_hook())
return "digraph code {\n%s\n}" % (
"\n".join(lines)
)
return "digraph code {\n%s\n}" % ("\n".join(lines))
# }}}
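# A minimal sketch (not part of the module source) of rendering a dependency
# graph; the Nop statements and their ids are invented for the example.
#
#     from pymbolic.imperative.statement import Nop
#     stmts = [Nop(id="a"),
#              Nop(id="b", depends_on=["a"]),
#              Nop(id="c", depends_on=["a", "b"])]
#     dot = get_dot_dependency_graph(stmts, use_stmt_ids=True)
#     print(dot)   # render with e.g. "dot -Tsvg" from graphviz
#     # The transitive reduction above drops the redundant c -> a dependency,
#     # so the output contains only the edges "b -> a" and "c -> b".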
# {{{ graphviz / dot interactive show
def show_dot(dot_code):
"""Show the graph represented by *dot_code* in a browser.
Can be called on the result of :func:`get_dot_dependency_graph`.
"""
from tempfile import mkdtemp
temp_dir = mkdtemp(prefix="tmp_dagrt_dot")
dot_file_name = "code.dot"
from os.path import join
with open(join(temp_dir, dot_file_name), "w") as dotf:
dotf.write(dot_code)
svg_file_name = "code.svg"
from subprocess import check_call
check_call(["dot", "-Tsvg", "-o", svg_file_name, dot_file_name],
cwd=temp_dir)
full_svg_file_name = join(temp_dir, svg_file_name)
logger.info("show_dot_dependency_graph: svg written to '%s'"
% full_svg_file_name)
from webbrowser import open as browser_open
browser_open("file://" + full_svg_file_name)
def show_dot(dot_code, output_to=None):
from warnings import warn
warn("pymbolic.imperative.utils.show_dot is deprecated. "
"It will stop working in July 2023. "
"Please use pytools.graphviz.show_dot instead.",
DeprecationWarning, stacklevel=2)
from pytools.graphviz import show_dot
return show_dot(dot_code, output_to)
# }}}
# vim: fdm=marker
from __future__ import division, absolute_import, print_function
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Andreas Kloeckner"
__copyright__ = """
Copyright (C) 2015 Andreas Kloeckner
Copyright (C) 2022 Kaushik Kulkarni
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -23,7 +27,15 @@ THE SOFTWARE.
"""
import ast
from typing import TYPE_CHECKING, Any, ClassVar
import pymbolic.primitives as p
from pymbolic.mapper import CachedMapper
if TYPE_CHECKING:
from pymbolic.typing import Expression
__doc__ = r'''
@@ -61,7 +73,7 @@ An example::
'''
class ASTMapper(object):
class ASTMapper:
def __call__(self, expr, *args, **kwargs):
return self.rec(expr, *args, **kwargs)
@@ -83,24 +95,32 @@ class ASTMapper(object):
def not_supported(self, expr):
raise NotImplementedError(
"%s does not know how to map type '%s'"
% (type(self).__name__,
"{} does not know how to map type '{}'".format(
type(self).__name__,
type(expr).__name__))
# {{{ mapper
class ASTToPymbolic(ASTMapper):
def _add(x, y):
return p.Sum((x, y))
def _add(x, y):
return p.Sum((x, y))
def _sub(x, y):
return p.Sum((x, p.Product(((-1), y))))
def _mult(x, y):
return p.Product((x, y))
def _neg(x):
return -x
def _sub(x, y):
return p.Sum((x, p.Product(((-1), y))))
def _mult(x, y):
return p.Product((x, y))
class ASTToPymbolic(ASTMapper):
bin_op_map = {
bin_op_map: ClassVar[dict[type[ast.operator], Any]] = {
ast.Add: _add,
ast.Sub: _sub,
ast.Mult: _mult,
@@ -116,43 +136,38 @@ class ASTToPymbolic(ASTMapper):
ast.BitAnd: p.BitwiseAnd,
}
def map_BinOp(self, expr):
def map_BinOp(self, expr): # noqa
try:
op_constructor = self.bin_op_map[type(expr.op)]
except KeyError:
raise NotImplementedError(
"%s does not know how to map operator '%s'"
% (type(self).__name__,
type(expr.op).__name__))
f"{type(self).__name__} does not know how to map operator "
f"'{type(expr.op).__name__}'") from None
return op_constructor(self.rec(expr.left), self.rec(expr.right))
def _neg(x):
return p.Product((-1), x)
unary_op_map = {
unary_op_map: ClassVar[dict[type[ast.unaryop], Any]] = {
ast.Invert: _neg,
ast.Not: p.LogicalNot,
# ast.UAdd:
ast.USub: _neg,
}
def map_UnaryOp(self, expr):
def map_UnaryOp(self, expr): # noqa
try:
op_constructor = self.unary_op_map[expr.op]
op_constructor = self.unary_op_map[type(expr.op)]
except KeyError:
raise NotImplementedError(
"%s does not know how to map operator '%s'"
% (type(self).__name__,
type(expr.op).__name__))
f"{type(self).__name__} does not know how to map operator "
f"'{type(expr.op).__name__}'") from None
return op_constructor(self.rec(expr.left), self.rec(expr.right))
return op_constructor(self.rec(expr.operand))
def map_IfExp(self, expr):
def map_IfExp(self, expr): # noqa
# (expr test, expr body, expr orelse)
return p.If(self.rec(expr.test), self.rec(expr.body), self.rec(expr.orelse))
comparison_op_map = {
comparison_op_map: ClassVar[dict[type[ast.cmpop], str]] = {
ast.Eq: "==",
ast.NotEq: "!=",
ast.Lt: "<",
@@ -165,7 +180,7 @@ class ASTToPymbolic(ASTMapper):
# NotIn
}
def map_Compare(self, expr):
def map_Compare(self, expr): # noqa
# (expr left, cmpop* ops, expr* comparators)
op, = expr.ops
@@ -173,46 +188,45 @@ class ASTToPymbolic(ASTMapper):
comp = self.comparison_op_map[type(op)]
except KeyError:
raise NotImplementedError(
"%s does not know how to map operator '%s'"
% (type(self).__name__,
type(op).__name__))
f"{type(self).__name__} does not know how to map operator "
f"'{type(expr.op).__name__}'") from None
# FIXME: Support strung-together comparisons
right, = expr.comparators
return p.Comparison(self.rec(expr.left), comp, self.rec(right))
def map_Call(self, expr):
def map_Call(self, expr): # noqa
# (expr func, expr* args, keyword* keywords)
func = self.rec(expr.func)
args = tuple(self.rec(arg) for arg in expr.args)
if expr.keywords:
args = tuple([self.rec(arg) for arg in expr.args])
if getattr(expr, "keywords", []):
return p.CallWithKwargs(func, args,
dict(
(kw.arg, self.rec(kw.value))
for kw in expr.keywords))
{
kw.arg: self.rec(kw.value)
for kw in expr.keywords})
else:
return p.Call(func, args)
def map_Num(self, expr):
def map_Num(self, expr): # noqa
# (object n) -- a number as a PyObject.
return expr.n
def map_Str(self, expr):
def map_Str(self, expr): # noqa
return expr.s
def map_Bytes(self, expr):
def map_Bytes(self, expr): # noqa
return expr.s
def map_NameConstant(self, expr):
def map_Constant(self, expr): # noqa
# (singleton value)
return expr.value
def map_Attribute(self, expr):
def map_Attribute(self, expr): # noqa
# (expr value, identifier attr, expr_context ctx)
return p.Lookup(self.rec(expr.value), expr.attr)
def map_Subscript(self, expr):
def map_Subscript(self, expr): # noqa
# (expr value, slice slice, expr_context ctx)
def none_or_rec(x):
if x is None:
@@ -234,13 +248,243 @@ class ASTToPymbolic(ASTMapper):
# def map_Starred(self, expr):
def map_Name(self, expr):
def map_Name(self, expr): # noqa
# (identifier id, expr_context ctx)
return p.Variable(expr.id)
def map_Tuple(self, expr):
def map_Tuple(self, expr): # noqa
# (expr* elts, expr_context ctx)
return tuple(self.rec(ti) for ti in expr.elts)
return tuple([self.rec(ti) for ti in expr.elts])
# }}}
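# A minimal sketch (not part of the module source) of converting Python source
# text to a pymbolic expression by way of the standard-library ast module; the
# expression text is invented for the example.
#
#     import ast
#     tree = ast.parse("f(2*x, y=3) + a[i]", mode="eval").body
#     expr = ASTToPymbolic()(tree)
#     # expr should be a pymbolic Sum wrapping a CallWithKwargs and a Subscript.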
# {{{ PymbolicToASTMapper
class PymbolicToASTMapper(CachedMapper[ast.expr, []]):
def map_variable(self, expr) -> ast.expr:
return ast.Name(id=expr.name)
def _map_multi_children_op(self,
children: tuple[Expression, ...],
op_type: ast.operator) -> ast.expr:
rec_children = [self.rec(child) for child in children]
result = rec_children[-1]
for child in rec_children[-2::-1]:
result = ast.BinOp(child, op_type, result)
return result
def map_sum(self, expr: p.Sum) -> ast.expr:
return self._map_multi_children_op(expr.children, ast.Add())
def map_product(self, expr: p.Product) -> ast.expr:
return self._map_multi_children_op(expr.children, ast.Mult())
def map_constant(self, expr: object) -> ast.expr:
return ast.Constant(expr, None)
def map_call(self, expr: p.Call) -> ast.expr:
return ast.Call(
func=self.rec(expr.function),
args=[self.rec(param) for param in expr.parameters],
keywords=[],
)
def map_call_with_kwargs(self, expr) -> ast.expr:
return ast.Call(
func=self.rec(expr.function),
args=[self.rec(param) for param in expr.parameters],
keywords=[
ast.keyword(
arg=kw,
value=self.rec(param))
for kw, param in sorted(expr.kw_parameters.items())])
def map_subscript(self, expr) -> ast.expr:
return ast.Subscript(value=self.rec(expr.aggregate),
slice=self.rec(expr.index))
def map_lookup(self, expr) -> ast.expr:
return ast.Attribute(self.rec(expr.aggregate),
expr.name)
def map_quotient(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.Div())
def map_floor_div(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.FloorDiv())
def map_remainder(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.Mod())
def map_power(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.base,
expr.exponent),
ast.Pow())
def map_left_shift(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.shiftee,
expr.shift),
ast.LShift())
def map_right_shift(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.RShift())
def map_bitwise_not(self, expr) -> ast.expr:
return ast.UnaryOp(ast.Invert(), self.rec(expr.child))
def map_bitwise_or(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitOr())
def map_bitwise_xor(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitXor())
def map_bitwise_and(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitAnd())
def map_logical_not(self, expr) -> ast.expr:
return ast.UnaryOp(ast.Not(), self.rec(expr.child))
def map_logical_or(self, expr) -> ast.expr:
return ast.BoolOp(ast.Or(), [self.rec(child)
for child in expr.children])
def map_logical_and(self, expr) -> ast.expr:
return ast.BoolOp(ast.And(), [self.rec(child)
for child in expr.children])
def map_list(self, expr: list[Any]) -> ast.expr:
return ast.List([self.rec(el) for el in expr])
def map_tuple(self, expr: tuple[Any, ...]) -> ast.expr:
return ast.Tuple([self.rec(el) for el in expr])
def map_if(self, expr: p.If) -> ast.expr:
return ast.IfExp(test=self.rec(expr.condition),
body=self.rec(expr.then),
orelse=self.rec(expr.else_))
def map_nan(self, expr: p.NaN) -> ast.expr:
assert expr.data_type is not None
if isinstance(expr.data_type(float("nan")), float):
return ast.Call(
ast.Name(id="float"),
args=[ast.Constant("nan")],
keywords=[])
else:
# TODO: would need attributes of NumPy
raise NotImplementedError("Non-float nan not implemented")
def map_slice(self, expr: p.Slice) -> ast.expr:
return ast.Slice(*[None if child is None else self.rec(child)
for child in expr.children])
def map_numpy_array(self, expr) -> ast.expr:
raise NotImplementedError
def map_multivector(self, expr) -> ast.expr:
raise NotImplementedError
def map_common_subexpression(self, expr) -> ast.expr:
raise NotImplementedError
def map_substitution(self, expr) -> ast.expr:
raise NotImplementedError
def map_derivative(self, expr) -> ast.expr:
raise NotImplementedError
def map_if_positive(self, expr) -> ast.expr:
raise NotImplementedError
def map_comparison(self, expr: p.Comparison) -> ast.expr:
raise NotImplementedError
def map_wildcard(self, expr) -> ast.expr:
raise NotImplementedError
def map_dot_wildcard(self, expr) -> ast.expr:
raise NotImplementedError
def map_star_wildcard(self, expr) -> ast.expr:
raise NotImplementedError
def map_function_symbol(self, expr) -> ast.expr:
raise NotImplementedError
def map_min(self, expr) -> ast.expr:
raise NotImplementedError
def map_max(self, expr) -> ast.expr:
raise NotImplementedError
def to_python_ast(expr) -> ast.expr:
"""
Maps *expr* to :class:`ast.expr`.
"""
return PymbolicToASTMapper()(expr)
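# A minimal sketch (not part of the module source) of round-tripping a
# pymbolic expression back to Python source text; the expression is invented.
#
#     import ast
#     from pymbolic import parse
#     tree = to_python_ast(parse("(x + 3) * y"))
#     print(ast.unparse(ast.fix_missing_locations(tree)))   # (x + 3) * y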
def to_evaluatable_python_function(expr: Expression,
fn_name: str
) -> str:
"""
Returns a :class:`str` of Python code defining a single function *fn_name*
that takes the variables in *expr* as keyword-only arguments and returns
the evaluated value of *expr*.
.. testsetup::
>>> from pymbolic import parse
>>> from pymbolic.interop.ast import to_evaluatable_python_function
.. doctest::
>>> expr = parse("S//32 + E%32")
>>> print(to_evaluatable_python_function(expr, "foo"))
def foo(*, E, S):
return S // 32 + E % 32
"""
from pymbolic.mapper.dependency import CachedDependencyMapper
dep_mapper: CachedDependencyMapper[[]] = (
CachedDependencyMapper(composite_leaves=True))
deps: list[str] = []
for dep in dep_mapper(expr):
if isinstance(dep, p.Variable):
deps.append(dep.name)
else:
raise NotImplementedError(f"{dep!r} is not supported")
ast_func = ast.FunctionDef(name=fn_name,
args=ast.arguments(args=[],
posonlyargs=[],
kwonlyargs=[ast.arg(dep, None)
for dep in sorted(deps)],
kw_defaults=[None]*len(deps),
vararg=None,
kwarg=None,
defaults=[]),
body=[ast.Return(to_python_ast(expr))],
decorator_list=[])
ast_module = ast.Module([ast_func], type_ignores=[])
return ast.unparse(ast.fix_missing_locations(ast_module))
# }}}