Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • isuruf/pymbolic
  • inducer/pymbolic
  • xywei/pymbolic
  • wence-/pymbolic
  • kaushikcfd/pymbolic
  • fikl2/pymbolic
  • zweiner2/pymbolic
7 results
Show changes
Showing
with 3607 additions and 279 deletions
Primitives (Basic Objects)
==========================
.. automodule:: pymbolic.typing
.. automodule:: pymbolic.primitives
.. vim: sw=4
#! /bin/sh
rsync --verbose --archive --delete _build/html/ doc-upload:doc/pymbolic
Utilities for dealing with expressions
======================================
Parser
------
.. currentmodule:: pymbolic
.. function:: parse(expr_str)
Return a :class:`pymbolic.primitives.ExpressionNode` tree corresponding
to *expr_str*.
The parser is also relatively easy to extend. See the source code of the following
class.
.. automodule:: pymbolic.parser
.. autoclass:: Parser
Compiler
--------
.. automodule:: pymbolic.compiler
.. autoclass:: CompiledExpression
.. method:: __call__(*args)
Interoperability with other symbolic systems
============================================
Interoperability with :mod:`sympy`
----------------------------------
.. automodule:: pymbolic.interop.sympy
Interoperability with Maxima
----------------------------
.. automodule:: pymbolic.interop.maxima
Interoperability with Python's :mod:`ast` module
------------------------------------------------
.. automodule:: pymbolic.interop.ast
Interoperability with :mod:`matchpy.functions` module
-----------------------------------------------------
.. automodule:: pymbolic.interop.matchpy
Visualizing Expressions
=======================
.. autofunction:: pymbolic.imperative.utils.get_dot_dependency_graph
# See https://github.com/inducer/pymbolic/pull/110 for context
import sys
from pymbolic import parse
from pymbolic.mapper import CachedIdentityMapper
from pymbolic.mapper.optimize import optimize_mapper
from pymbolic.primitives import Variable
code = ("(-1)*((cse_577[_pt_data_48[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
"_pt_data_49[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_48[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_577[_pt_data_46[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_47[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_46[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_577[_pt_data_7[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_43[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_7[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_577[_pt_data_44[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_45[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_44[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_579[_pt_data_68[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_69[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_68[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_579[_pt_data_66[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_67[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_66[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_579[_pt_data_50[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_63[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_50[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_579[_pt_data_64[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_65[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_64[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_581[_pt_data_88[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_89[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_88[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_581[_pt_data_86[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_87[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_86[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_581[_pt_data_70[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_83[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_70[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_581[_pt_data_84[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_85[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_84[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_582[_pt_data_107[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_108[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_107[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_582[_pt_data_105[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
" _pt_data_106[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_105[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_582[_pt_data_90[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_102[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_90[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_582[_pt_data_103[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_104[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_103[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0))"
" + (cse_572[_pt_data_48[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_49[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_48[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0) "
"+ (cse_572[_pt_data_46[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_47[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_46[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0) "
"+ (cse_572[_pt_data_7[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_43[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_7[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_572[_pt_data_44[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_45[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_44[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_573[_pt_data_68[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_69[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_68[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_573[_pt_data_66[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_67[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_66[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_573[_pt_data_50[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_63[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_50[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_573[_pt_data_64[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_65[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_64[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_574[_pt_data_88[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_89[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_88[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_574[_pt_data_86[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_87[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_86[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0) "
"+ (cse_574[_pt_data_70[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_83[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_70[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_574[_pt_data_84[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_85[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_84[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_575[_pt_data_107[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_108[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_107[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_575[_pt_data_105[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_106[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_105[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_575[_pt_data_90[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_102[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_90[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
" + (cse_575[_pt_data_103[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0], _pt_data_104[(iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 10]]"
" if _pt_data_103[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0] != -1 else 0)"
)
expr = parse(code)
expr = CachedIdentityMapper()(expr) # remove duplicate nodes
replacements = {
"iface_ensm15": Variable("_0"),
"iel_ensm15": Variable("_1"),
"idof_ensm15": Variable("_2"),
}
@optimize_mapper(drop_args=True, drop_kwargs=True,
# inline_cache=True, inline_rec=True,
inline_get_cache_key=True,
print_modified_code_file=sys.stdout)
class Renamer(CachedIdentityMapper):
def map_variable(self, expr):
return replacements.get(expr.name, expr)
def get_cache_key(self, expr):
# Must add 'type(expr)', to differentiate between python scalar types.
# In Python, the following conditions are true: "hash(4) == hash(4.0)"
# and "4 == 4.0", but their traversal results cannot be re-used.
return (type(expr), expr)
def main():
mapper = Renamer()
mapper(expr)
# print(type(new_expr))
if __name__ == "__main__":
from time import time
if 1:
t_start = time()
for _ in range(10_000):
main()
t_end = time()
print(f"Took: {t_end-t_start} secs.")
else:
import pyinstrument
from pyinstrument.renderers import SpeedscopeRenderer
prof = pyinstrument.Profiler()
with prof:
for _ in range(10_000):
main()
with open("ss.json", "w") as outf:
outf.write(prof.output(SpeedscopeRenderer(show_all=True)))
import parser
import compiler
import pymbolic.mapper.evaluator
import pymbolic.mapper.stringifier
import pymbolic.mapper.dependency
import pymbolic.mapper.substitutor
import pymbolic.mapper.differentiator
import pymbolic.mapper.expander
import pymbolic.mapper.flattener
import pymbolic.primitives
from pymbolic.polynomial import Polynomial
var = pymbolic.primitives.Variable
variables = pymbolic.primitives.variables
flattened_sum = pymbolic.primitives.flattened_sum
subscript = pymbolic.primitives.subscript
flattened_product = pymbolic.primitives.flattened_product
quotient = pymbolic.primitives.quotient
linear_combination = pymbolic.primitives.linear_combination
cse = pymbolic.primitives.make_common_subexpression
make_sym_vector = pymbolic.primitives.make_sym_vector
parse = pymbolic.parser.parse
evaluate = pymbolic.mapper.evaluator.evaluate
evaluate_kw = pymbolic.mapper.evaluator.evaluate_kw
compile = pymbolic.compiler.compile
substitute = pymbolic.mapper.substitutor.substitute
differentiate = pymbolic.mapper.differentiator.differentiate
expand = pymbolic.mapper.expander.expand
flatten = pymbolic.mapper.flattener.flatten
def simplify(x):
# FIXME: Not yet implemented
return x
def grad(expression, variables):
return [differentiate(expression, var) for var in variables]
def jacobian(expression_list, variables):
return [grad(expr, variables) for expr in expression_list]
def laplace(expression, variables):
return sum(differentiate(differentiate(expression,var), var) for var in variables)
class VectorFunction:
def __init__(self, function_list, variables=[]):
self.FunctionList = [pymbolic.compile(expr, variables=variables)
for expr in function_list]
def __call__(self, x):
import pylinear.array as num
return num.array([ func(x) for func in self.FunctionList ])
class MatrixFunction:
def __init__(self, function_list, variables=[]):
self. FunctionList = [[pymbolic.compile(expr, variables=variables)
for expr in outer]
for outer in function_list]
def __call__(self, x):
import pylinear.array as num
return num.array([[func(x) for func in flist ] for flist in self.FunctionList])
if __name__ == "__main__":
import math
#ex = parse("0 + 4.3e3j * alpha * math.cos(x+math.pi)") + 5
#print ex
#print repr(parse("x+y"))
#print evaluate(ex, {"alpha":5, "math":math, "x":-math.pi})
#compiled = compile(substitute(ex, {var("alpha"): 5}))
#print compiled(-math.pi)
#import cPickle as pickle
#pickle.dumps(compiled)
#print hash(ex)
#print is_constant(ex)
#print substitute(ex, {"alpha": ex})
#ex2 = parse("math.cos(x**2/x)")
#print ex2
#print differentiate(ex2, parse("x"))
x0 = parse("x[0]")
ex = parse("1-x[0]")
print differentiate(ex, x0)
#print expand(ex)
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial
from pytools import module_getattr_for_deprecations
from . import compiler, parser, primitives
from .compiler import compile
from .mapper import (
dependency,
differentiator,
distributor,
evaluator,
flattener,
stringifier,
substitutor,
)
from .mapper.differentiator import differentiate, differentiate as diff
from .mapper.distributor import distribute, distribute as expand
from .mapper.evaluator import evaluate, evaluate_kw
from .mapper.flattener import flatten
from .mapper.substitutor import substitute
from .parser import parse
from .primitives import ( # noqa: N813
ExpressionNode,
Variable,
Variable as var,
disable_subscript_by_getitem,
expr_dataclass,
flattened_product,
flattened_sum,
linear_combination,
make_common_subexpression as cse,
make_sym_vector,
quotient,
subscript,
variables,
)
from .typing import (
ArithmeticExpression,
Bool,
Expression,
Expression as _TypingExpression,
Number,
RealNumber,
Scalar,
)
from pymbolic.version import VERSION_TEXT as __version__ # noqa
__all__ = (
"ArithmeticExpression",
"Bool",
"Expression",
"ExpressionNode",
"Number",
"RealNumber",
"Scalar",
"Variable",
"compile",
"compiler",
"cse",
"dependency",
"diff",
"differentiate",
"differentiator",
"disable_subscript_by_getitem",
"distribute",
"distributor",
"evaluate",
"evaluate_kw",
"evaluator",
"expand",
"expr_dataclass",
"flatten",
"flattened_product",
"flattened_sum",
"flattener",
"linear_combination",
"make_sym_vector",
"parse",
"parser",
"primitives",
"quotient",
"stringifier",
"subscript",
"substitute",
"substitutor",
"var",
"variables",
)
__getattr__ = partial(module_getattr_for_deprecations, __name__, {
"ExpressionT": ("pymbolic.typing.Expression", _TypingExpression, 2026),
"ArithmeticExpressionT": ("ArithmeticExpression", ArithmeticExpression, 2026),
"BoolT": ("Bool", Bool, 2026),
"ScalarT": ("Scalar", Scalar, 2026),
})
from __future__ import division
import cmath
from pytools import memoize
"""
.. autofunction:: integer_power
.. autofunction:: extended_euclidean
.. autofunction:: gcd
.. autofunction:: lcm
.. autofunction:: fft
.. autofunction:: ifft
.. autofunction:: sym_fft
.. autofunction:: reduced_row_echelon_form
.. autofunction:: solve_affine_equations_for
"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import operator
import sys
from typing import TYPE_CHECKING, overload
from warnings import warn
from pytools import MovedFunctionDeprecationWrapper, memoize
if TYPE_CHECKING:
import numpy as np
if getattr(sys, "_BUILDING_SPHINX_DOCS", None):
import numpy as np # noqa: TC002
# {{{ integer powers
def integer_power(x, n, one=1):
# http://c2.com/cgi/wiki?IntegerPowerAlgorithm
"""Compute :math:`x^n` using only multiplications.
See also the `C2 wiki <https://wiki.c2.com/?IntegerPowerAlgorithm>`__.
"""
assert isinstance(n, int)
if n < 0:
raise RuntimeError, "the integer power algorithm does not work for negative numbers"
raise RuntimeError("the integer power algorithm does not "
"work for negative numbers")
aux = one
......@@ -22,28 +75,20 @@ def integer_power(x, n, one=1):
x = x * x
n //= 2
return aux
def gcd(q, r):
return extended_euclidean(q, r)[0]
def gcd_many(*args):
if len(args) == 0:
return 1
elif len(args) == 1:
return args[0]
else:
return reduce(gcd, args)
def lcm(q, r):
return abs(q*r)//gcd(q, r)
# }}}
# {{{ euclidean algorithm
def extended_euclidean(q, r):
"""Return a tuple (p, a, b) such that p = aq + br,
where p is the greatest common divisor.
"""Return a tuple *(p, a, b)* such that :math:`p = aq + br`,
where *p* is the greatest common divisor of *q* and *r*.
See also the
`Wikipedia article on the Euclidean algorithm
<https://en.wikipedia.org/wiki/Euclidean_algorithm>`__.
"""
import pymbolic.traits as traits
......@@ -55,92 +100,159 @@ def extended_euclidean(q, r):
p, a, b = extended_euclidean(r, q)
return p, b, a
Q = 1, 0
R = 0, 1
Q = 1, 0 # noqa
R = 0, 1 # noqa
while r:
quot, t = divmod(q, r)
T = Q[0] - quot*R[0], Q[1] - quot*R[1]
T = Q[0] - quot*R[0], Q[1] - quot*R[1] # noqa
q, r = r, t
Q, R = R, T
Q, R = R, T # noqa: N806
return q, Q[0], Q[1]
def gcd(q, r):
return extended_euclidean(q, r)[0]
def gcd_many(*args):
if len(args) == 0:
return 1
elif len(args) == 1:
return args[0]
else:
from functools import reduce
return reduce(gcd, args)
def lcm(q, r):
return abs(q*r)//gcd(q, r)
# }}}
# {{{ fft
@memoize
def find_factors(N):
def find_factors(n):
from math import sqrt
N1 = 2
max_N1 = int(sqrt(N))+1
while N % N1 != 0 and N1 <= max_N1:
N1 += 1
n1 = 2
max_n1 = int(sqrt(n))+1
while n % n1 != 0 and n1 <= max_n1:
n1 += 1
if N1 > max_N1:
N1 = N
if n1 > max_n1:
n1 = n
N2 = N // N1
n2 = n // n1
return N1, N2
return n1, n2
def fft(x, sign=1,
wrap_intermediate=None,
*,
wrap_intermediate_with_level=None,
complex_dtype=None,
custom_np=None, level=0):
r"""Computes the Fourier transform of x:
.. math::
def fft(x, sign=1, wrap_intermediate=lambda x: x):
"""Computes the Fourier transform of x:
F[x]_k = \sum_{j=0}^{n-1} z^{kj} x_j
F[x]_k = \sum_{j=0}^{n-1} z^{kj} x_j
where :math:`z = \exp(-2i\pi\operatorname{sign}/n)` and ``n == len(x)``.
Works for all positive *n*.
where z = exp(sign*-2j*pi/n) and n = len(x).
See also `Wikipedia <https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm>`__.
"""
# http://en.wikipedia.org/wiki/Cooley-Tukey_FFT_algorithm
# revision 293076305, http://is.gd/1c7PI
# revision 293076305
# https://en.wikipedia.org/w/index.php?title=Cooley-Tukey_FFT_algorithm&oldid=293076305
# {{{ parameter processing
if wrap_intermediate is not None and wrap_intermediate_with_level is not None:
raise TypeError("may specify at most one of wrap_intermediate and "
"wrap_intermediate_with_level")
if wrap_intermediate is not None:
from warnings import warn
warn("wrap_intermediate is deprecated. Use wrap_intermediate_with_level "
"instead. wrap_intermediate will stop working in 2023.",
DeprecationWarning, stacklevel=2)
def wrap_intermediate_with_level(level, x): # pylint: disable=function-redefined
return wrap_intermediate(x)
if wrap_intermediate_with_level is None:
def wrap_intermediate_with_level(level, x):
return x
from math import pi
import numpy
if custom_np is None:
import numpy as custom_np
if complex_dtype is None:
if x.dtype.kind == "c":
complex_dtype = x.dtype
else:
from warnings import warn
warn("Not supplying complex_dtype is deprecated, falling back "
"to complex128 for now. This will stop working in 2023.",
DeprecationWarning, stacklevel=2)
complex_dtype = custom_np.complex128
N = len(x)
complex_dtype = custom_np.dtype(complex_dtype)
if N == 1:
# }}}
n = len(x)
if n == 1:
return x
N1, N2 = find_factors(N)
N1, N2 = find_factors(n) # noqa: N806
scalar_tp = complex_dtype.type
sub_ffts = [
wrap_intermediate(
fft(x[n1::N1], sign, wrap_intermediate)
* numpy.exp(numpy.linspace(0, sign*(-2j)*pi*n1/N1, N2,
endpoint=False)))
wrap_intermediate_with_level(level,
fft(x[n1::N1], sign, wrap_intermediate, custom_np=custom_np,
level=level+1, complex_dtype=complex_dtype)
* custom_np.exp(
sign*-2j*pi*n1/(N1*N2)
* custom_np.arange(0, N2, dtype=complex_dtype)))
for n1 in range(N1)]
return numpy.hstack([
sum(subvec * cmath.exp(sign*(-2j)*pi*n1*k1/N1)
return custom_np.concatenate([
sum(subvec * scalar_tp(custom_np.exp(sign*(-2j)*pi*n1*k1/N1))
for n1, subvec in enumerate(sub_ffts))
for k1 in range(N1)
])
def ifft(x, wrap_intermediate=lambda x:x):
return (1/len(x))*fft(x, -1, wrap_intermediate)
], axis=0)
def ifft(x, wrap_intermediate=None,
*,
wrap_intermediate_with_level=None,
complex_dtype=None,
custom_np=None):
return (1/len(x))*fft(x, sign=-1, wrap_intermediate=wrap_intermediate,
wrap_intermediate_with_level=wrap_intermediate_with_level,
complex_dtype=complex_dtype, custom_np=custom_np)
def sym_fft(x, sign=1):
"""Perform an FFT on the numpy object array x.
"""Perform a (symbolic) FFT on the :mod:`numpy` object array x.
Remove near-zero floating point constants, insert
CommonSubexpression wrappers at opportune points.
:class:`pymbolic.primitives.CommonSubexpression`
wrappers at opportune points.
"""
from pymbolic.mapper import IdentityMapper, CSECachingMapperMixin
from pymbolic.mapper import CSECachingMapperMixin, IdentityMapper
class NearZeroKiller(CSECachingMapperMixin, IdentityMapper):
map_common_subexpression_uncached = \
IdentityMapper.map_common_subexpression
......@@ -164,43 +276,238 @@ def sym_fft(x, sign=1):
def wrap_intermediate(x):
if len(x) > 1:
from pymbolic.primitives import CommonSubexpression
from pymbolic.primitives import CommonSubexpression, cse_scope
result = numpy.empty(len(x), dtype=object)
for i, x_i in enumerate(x):
result[i] = CommonSubexpression(x_i)
result[i] = CommonSubexpression(x_i, scope=cse_scope.EVALUATION)
return result
else:
return x
return NearZeroKiller()(
fft(wrap_intermediate(x), sign=sign, wrap_intermediate=wrap_intermediate))
fft(wrap_intermediate(x), sign=sign,
wrap_intermediate=wrap_intermediate))
# }}}
def csr_matrix_multiply(S, x):
"""Multiplies a scipy.sparse.csr_matrix S by an object-array vector x.
def csr_matrix_multiply(S, x): # noqa
"""Multiplies a :class:`scipy.sparse.csr_matrix` S by an object-array vector x.
"""
h, w = S.shape
h, _w = S.shape
import numpy
result = numpy.empty_like(x)
for i in xrange(h):
result[i] = sum(S.data[idx]*x[S.indices[idx]]
for i in range(h):
result[i] = sum(S.data[idx]*x[S.indices[idx]] # pylint:disable=unsupported-assignment-operation
for idx in range(S.indptr[i], S.indptr[i+1]))
return result
# {{{ reduced_row_echelon_form
@overload
def reduced_row_echelon_form(
mat: np.ndarray,
*, integral: bool | None = None,
) -> np.ndarray:
...
@overload
def reduced_row_echelon_form(
mat: np.ndarray,
rhs: np.ndarray,
*, integral: bool | None = None,
) -> tuple[np.ndarray, np.ndarray]:
...
def reduced_row_echelon_form(
mat: np.ndarray,
rhs: np.ndarray | None = None,
integral: bool | None = None,
) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
m, n = mat.shape
mat = mat.copy()
if rhs is not None:
rhs = rhs.copy()
if integral is None:
warn(
"Not specifying 'integral' is deprecated, please add it as an argument. "
"This will stop being supported in 2025.",
DeprecationWarning, stacklevel=2)
div_func = operator.floordiv if integral else operator.truediv
i = 0
j = 0
while i < m and j < n:
# {{{ find pivot in column j, starting in row i
nonz_row = None
for k in range(i, m):
if mat[k, j]:
nonz_row = k
break
# }}}
if nonz_row is not None:
# swap rows i and nonz
mat[i], mat[nonz_row] = \
(mat[nonz_row].copy(), mat[i].copy())
if rhs is not None:
rhs[i], rhs[nonz_row] = \
(rhs[nonz_row].copy(), rhs[i].copy())
for u in range(0, m):
if u == i:
continue
if not mat[u, j]:
# already 0
continue
ell = lcm(mat[u, j], mat[i, j])
u_fac = div_func(ell, mat[u, j])
i_fac = div_func(ell, mat[i, j])
mat[u] = u_fac*mat[u] - i_fac*mat[i]
if rhs is not None:
rhs[u] = u_fac*rhs[u] - i_fac*rhs[i]
assert mat[u, j] == 0
i += 1
j += 1
if integral:
for i in range(m):
g = gcd_many(*(
[a for a in mat[i] if a]
+
[a for a in rhs[i] if a] if rhs is not None else []))
mat[i] //= g
if rhs is not None:
rhs[i] //= g
import numpy as np
from pymbolic.mapper.flattener import flatten
vec_flatten = np.vectorize(flatten, otypes=[object])
for i in range(m):
mat[i] = vec_flatten(mat[i])
if rhs is not None:
rhs[i] = vec_flatten(rhs[i])
if rhs is None:
return mat
else:
return mat, rhs
# }}}
gaussian_elimination = MovedFunctionDeprecationWrapper(reduced_row_echelon_form, "2025")
# {{{ symbolic (linear) equation solving
def solve_affine_equations_for(unknowns, equations):
"""
:arg unknowns: A list of variable names for which to solve.
:arg equations: A list of tuples ``(lhs, rhs)``.
:return: a dict mapping unknown names to their values.
"""
import numpy as np
from pymbolic.mapper.dependency import DependencyMapper
dep_map = DependencyMapper(composite_leaves=True)
# fix an order for unknowns
from pymbolic import var
unknowns = [var(u) for u in unknowns]
unknowns_set = set(unknowns)
unknown_idx_lut = {tgt_name: idx
for idx, tgt_name in enumerate(unknowns)}
# Find non-unknown variables, fix order for them
# Last non-unknown is constant.
parameters = set()
for lhs, rhs in equations:
parameters.update(dep_map(lhs) - unknowns_set)
parameters.update(dep_map(rhs) - unknowns_set)
parameters_list = list(parameters)
parameter_idx_lut = {var_name: idx
for idx, var_name in enumerate(parameters_list)}
from pymbolic.mapper.coefficient import CoefficientCollector
coeff_coll = CoefficientCollector()
# {{{ build matrix and rhs
mat = np.zeros((len(equations), len(unknowns_set)), dtype=object)
rhs_mat = np.zeros((len(equations), len(parameters)+1), dtype=object)
for i_eqn, (lhs, rhs) in enumerate(equations):
for lhs_factor, coeffs in [(1, coeff_coll(lhs)), (-1, coeff_coll(rhs))]:
for key, coeff in coeffs.items():
if key in unknowns_set:
mat[i_eqn, unknown_idx_lut[key]] = lhs_factor*coeff
elif key in parameters:
rhs_mat[i_eqn, parameter_idx_lut[key]] = -lhs_factor*coeff
elif key == 1:
rhs_mat[i_eqn, -1] = -lhs_factor*coeff
else:
raise ValueError(f"key '{key}' not understood")
# }}}
mat, rhs_mat = reduced_row_echelon_form(mat, rhs_mat, integral=True)
# FIXME /!\ Does not check for overdetermined system.
result = {}
for j, unknown in enumerate(unknowns):
(nonz_row,) = np.where(mat[:, j])
if len(nonz_row) != 1:
raise RuntimeError(f"cannot uniquely solve for '{unknown}'")
(nonz_row,) = nonz_row
if abs(mat[nonz_row, j]) != 1:
raise RuntimeError(
f"division with remainder in linear solve for '{unknown}'")
div = mat[nonz_row, j]
unknown_val = int(rhs_mat[nonz_row, -1]) // div
for parameter, coeff in zip(
parameters_list, rhs_mat[nonz_row, :-1], strict=True):
unknown_val += (int(coeff) // div) * parameter
result[unknown] = unknown_val
if 0:
for lhs, rhs in equations:
print(lhs, "=", rhs)
print("-------------------")
for lhs, rhs in result.items():
print(lhs, "=", rhs)
return result
# }}}
if __name__ == "__main__":
import integer
q = integer.Integer(14)
r = integer.Integer(22)
gcd, a, b = extended_euclidean(q, r)
print gcd, "=", a, "*", q, "+", b, "*", r
print a*q + b*r
# vim: foldmethod=marker
import math
from __future__ import annotations
import pymbolic
from pymbolic.mapper.stringifier import StringifyMapper, PREC_NONE, PREC_SUM, PREC_POWER
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
def _constant_mapper(c):
# work around numpy bug #1137 (locale-sensitive repr)
# http://projects.scipy.org/numpy/ticket/1137
try:
import numpy
except ImportError:
pass
else:
if isinstance(c, numpy.floating):
c = float(c)
elif isinstance(c, numpy.complexfloating):
c = complex(c)
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
return repr(c)
import math
import pymbolic
from pymbolic.mapper.stringifier import PREC_NONE, StringifyMapper
class CompileMapper(StringifyMapper):
def __init__(self):
StringifyMapper.__init__(self,
constant_mapper=_constant_mapper)
def map_polynomial(self, expr, enclosing_prec):
# Use Horner's scheme to evaluate the polynomial
sbase = self(expr.base, PREC_POWER)
def stringify_exp(exp):
if exp == 0:
return ""
elif exp == 1:
return "*%s" % sbase
else:
return "*%s**%s" % (sbase, exp)
result = ""
rev_data = expr.data[::-1]
for i, (exp, coeff) in enumerate(rev_data):
if i+1 < len(rev_data):
next_exp = rev_data[i+1][0]
else:
next_exp = 0
result = "(%s+%s)%s" % (result, self(coeff, PREC_SUM),
stringify_exp(exp-next_exp))
#print "A", result
#print "B", expr
if enclosing_prec > PREC_SUM and len(expr.data) > 1:
return "(%s)" % result
class CompileMapper(StringifyMapper):
def map_constant(self, expr, enclosing_prec):
# work around numpy bug #1137 (locale-sensitive repr)
# https://github.com/numpy/numpy/issues/1735
try:
import numpy
except ImportError:
pass
else:
return result
if isinstance(expr, numpy.floating):
expr = float(expr)
elif isinstance(expr, numpy.complexfloating):
expr = complex(expr)
return repr(expr)
def map_numpy_array(self, expr, enclosing_prec):
def stringify_leading_dimension(ary):
......@@ -64,32 +53,36 @@ class CompileMapper(StringifyMapper):
else:
rec = stringify_leading_dimension
return "[%s]" % (", ".join(rec(x) for x in ary))
return "[{}]".format(", ".join(rec(x) for x in ary))
return "numpy.array(%s)" % stringify_leading_dimension(expr)
return "numpy.array({})".format(stringify_leading_dimension(expr))
def map_foreign(self, expr, enclosing_prec):
return StringifyMapper.map_foreign(self, expr, enclosing_prec)
class CompiledExpression:
"""This class encapsulates a compiled expression.
"""This class encapsulates an expression compiled into Python bytecode
for faster evaluation.
The main reason for its existence is the fact that a dynamically-constructed
lambda function is not picklable.
Its instances (unlike plain lambdas) are pickleable.
"""
def __init__(self, expression, variables = []):
def __init__(self, expression, variables=None):
"""
:arg variables: The first arguments (as strings or
:class:`pymbolic.primitives.Variable` instances) to be used for the
compiled function. All variables used by the expression and not
present here are added in lexicographic order.
"""
if variables is None:
variables = []
self._compile(expression, variables)
def _compile(self, expression, variables):
import pymbolic.primitives as primi
self._Expression = expression
self._Variables = [primi.make_variable(v) for v in variables]
self._compile()
def _compile(self):
ctx = self.context().copy()
try:
......@@ -103,29 +96,27 @@ class CompiledExpression:
used_variables = DependencyMapper(
composite_leaves=False)(self._Expression)
used_variables -= set(self._Variables)
used_variables -= set(pymbolic.var(key) for key in ctx.keys())
used_variables -= {pymbolic.var(key) for key in list(ctx.keys())}
used_variables = list(used_variables)
used_variables.sort()
all_variables = self._Variables + used_variables
expr_s = CompileMapper()(self._Expression, PREC_NONE)
func_s = "lambda %s: %s" % (",".join(str(v) for v in all_variables),
func_s = "lambda {}: {}".format(",".join(str(v) for v in all_variables),
expr_s)
self.__call__ = eval(func_s, ctx)
def __getinitargs__(self):
return self._Expression, self._Variables
self._code = eval(func_s, ctx)
def __getstate__(self):
return None
return self._Expression, self._Variables
def __setstate__(self, state):
pass
self._compile(*state)
def __call__(self, *args):
return self._code(*args)
def context(self):
return {"math": math}
compile = CompiledExpression
from __future__ import division
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import pymbolic.primitives as prim
from pymbolic.mapper import IdentityMapper, WalkMapper
COMMUTATIVE_CLASSES = (prim.Sum, prim.Product)
COMMUTATIVE_CLASSES = (prim.Sum, prim.Product)
class NormalizedKeyGetter(object):
class NormalizedKeyGetter:
def __call__(self, expr):
if isinstance(expr, COMMUTATIVE_CLASSES):
kid_count = {}
for child in expr.children:
kid_count[child] = kid_count.get(child, 0) + 1
return type(expr), frozenset(kid_count.iteritems())
return type(expr), frozenset(kid_count.items())
else:
return expr
class UseCountMapper(WalkMapper):
def __init__(self, get_key):
self.subexpr_counts = {}
......@@ -55,9 +76,6 @@ class UseCountMapper(WalkMapper):
self.subexpr_counts[key] = 1
class CSEMapper(IdentityMapper):
def __init__(self, to_eliminate, get_key):
self.to_eliminate = to_eliminate
......@@ -72,8 +90,9 @@ class CSEMapper(IdentityMapper):
try:
return self.canonical_subexprs[key]
except KeyError:
new_expr = prim.wrap_in_cse(
getattr(IdentityMapper, expr.mapper_method)(self, expr))
new_expr = prim.make_common_subexpression(
getattr(IdentityMapper, expr.mapper_method)(self, expr)
)
self.canonical_subexprs[key] = new_expr
return new_expr
......@@ -95,39 +114,39 @@ class CSEMapper(IdentityMapper):
def map_common_subexpression(self, expr):
# Avoid creating CSE(CSE(...))
if type(expr) is prim.CommonSubexpression:
return prim.wrap_in_cse(self.rec(expr.child), expr.prefix)
return prim.make_common_subexpression(
self.rec(expr.child), expr.prefix, expr.scope
)
else:
# expr is of a derived CSE type
result = self.rec(expr.child)
if type(result) is prim.CommonSubexpression:
result = result.child
return type(expr)(result, expr.prefix, **expr.get_extra_properties())
return type(expr)(result, expr.prefix, expr.scope,
**expr.get_extra_properties())
def map_substitution(self, expr):
return type(expr)(
expr.child,
expr.variables,
tuple(self.rec(v) for v in expr.values))
tuple([self.rec(v) for v in expr.values]))
def tag_common_subexpressions(exprs):
get_key = NormalizedKeyGetter()
ucm = UseCountMapper(get_key)
if isinstance(exprs, prim.Expression):
if isinstance(exprs, prim.ExpressionNode):
raise TypeError("exprs should be an iterable of expressions")
for expr in exprs:
ucm(expr)
to_eliminate = set([subexpr_key
for subexpr_key, count in ucm.subexpr_counts.iteritems()
if count > 1])
to_eliminate = {subexpr_key
for subexpr_key, count in ucm.subexpr_counts.items()
if count > 1}
cse_mapper = CSEMapper(to_eliminate, get_key)
result = [cse_mapper(expr) for expr in exprs]
return result
import pymbolic.primitives as primitives
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import pymbolic.primitives as p
def sin(x):
return primitives.Call(primitives.Lookup(primitives.Variable("math"), "sin"), (x,))
return p.Call(p.Lookup(p.Variable("math"), "sin"), (x,))
def cos(x):
return primitives.Call(primitives.Lookup(primitives.Variable("math"), "cos"), (x,))
return p.Call(p.Lookup(p.Variable("math"), "cos"), (x,))
def tan(x):
return primitives.Call(primitives.Lookup(primitives.Variable("math"), "tan"), (x,))
return p.Call(p.Lookup(p.Variable("math"), "tan"), (x,))
def log(x):
return primitives.Call(primitives.Lookup(primitives.Variable("math"), "log"), (x,))
return p.Call(p.Lookup(p.Variable("math"), "log"), (x,))
def exp(x):
return primitives.Call(primitives.Lookup(primitives.Variable("math"), "exp"), (x,))
return p.Call(p.Lookup(p.Variable("math"), "exp"), (x,))
def sinh(x):
return p.Call(p.Lookup(p.Variable("math"), "sinh"), (x,))
def cosh(x):
return p.Call(p.Lookup(p.Variable("math"), "cosh"), (x,))
def tanh(x):
return p.Call(p.Lookup(p.Variable("math"), "tanh"), (x,))
def expm1(x):
return p.Call(p.Lookup(p.Variable("math"), "expm1"), (x,))
def fabs(x):
return p.Call(p.Lookup(p.Variable("math"), "fabs"), (x,))
def sign(x):
return p.Call(p.Lookup(p.Variable("math"), "copysign"), (1, x,))
This diff is collapsed.
from __future__ import annotations
__copyright__ = "Copyright (C) 2014 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
# This is experimental, undocumented, and could go away any second.
# Consider yourself warned.
from typing import TYPE_CHECKING, ClassVar
import pymbolic.geometric_algebra.primitives as prim
from pymbolic.geometric_algebra import MultiVector
from pymbolic.mapper import (
CachedMapper,
CollectedT,
Collector as CollectorBase,
CombineMapper as CombineMapperBase,
IdentityMapper as IdentityMapperBase,
P,
ResultT,
WalkMapper as WalkMapperBase,
)
from pymbolic.mapper.constant_folder import (
ConstantFoldingMapper as ConstantFoldingMapperBase,
)
from pymbolic.mapper.evaluator import EvaluationMapper as EvaluationMapperBase
from pymbolic.mapper.graphviz import GraphvizMapper as GraphvizMapperBase
from pymbolic.mapper.stringifier import (
PREC_NONE,
StringifyMapper as StringifyMapperBase,
)
if TYPE_CHECKING:
from collections.abc import Set
from pymbolic.primitives import ExpressionNode
class IdentityMapper(IdentityMapperBase[P]):
def map_nabla(
self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs
) -> ExpressionNode:
return expr
def map_nabla_component(self,
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
) -> ExpressionNode:
return expr
def map_derivative_source(self,
expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
) -> ExpressionNode:
operand = self.rec(expr.operand, *args, **kwargs)
if operand is expr.operand:
return expr
return type(expr)(operand, expr.nabla_id)
class CombineMapper(CombineMapperBase[ResultT, P]):
def map_derivative_source(
self, expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
return self.rec(expr.operand, *args, **kwargs)
class Collector(CollectorBase[CollectedT, P]):
def map_nabla(self,
expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs
) -> Set[CollectedT]:
return set()
def map_nabla_component(self,
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
) -> Set[CollectedT]:
return set()
class WalkMapper(WalkMapperBase[P]):
def map_nabla(self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> None:
self.visit(expr, *args, **kwargs)
self.post_visit(expr, *args, **kwargs)
def map_nabla_component(
self, expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
) -> None:
self.visit(expr, *args, **kwargs)
self.post_visit(expr, *args, **kwargs)
def map_derivative_source(
self, expr, *args: P.args, **kwargs: P.kwargs
) -> None:
if not self.visit(expr, *args, **kwargs):
return
self.rec(expr.operand, *args, **kwargs)
self.post_visit(expr, *args, **kwargs)
class EvaluationMapper(EvaluationMapperBase):
def map_nabla_component(self, expr):
return expr
map_nabla = map_nabla_component
def map_derivative_source(self, expr):
operand = self.rec(expr.operand)
if operand is expr.operand:
return expr
return type(expr)(operand, expr.nabla_id)
class StringifyMapper(StringifyMapperBase[[]]):
AXES: ClassVar[dict[int, str]] = {0: "x", 1: "y", 2: "z"}
def map_nabla(self, expr, enclosing_prec):
return f"∇[{expr.nabla_id}]"
def map_nabla_component(self, expr, enclosing_prec):
return "∇{}[{}]".format(
self.AXES.get(expr.ambient_axis, expr.ambient_axis),
expr.nabla_id)
def map_derivative_source(self, expr, enclosing_prec):
return r"D[{}]({})".format(expr.nabla_id, self.rec(expr.operand, PREC_NONE))
class GraphvizMapper(GraphvizMapperBase):
def map_derivative_source(self, expr):
self.lines.append(
'{} [label="D[{}]",shape=ellipse];'.format(
self.get_id(expr), expr.nabla_id))
if not self.visit(expr, node_printed=True):
return
self.rec(expr.operand)
self.post_visit(expr)
# {{{ dimensionalizer
class Dimensionalizer(EvaluationMapper):
"""
.. attribute:: ambient_dim
Dimension of ambient space. Must be provided by subclass.
"""
@property
def ambient_dim(self):
raise NotImplementedError
def map_multivector_variable(self, expr):
from pymbolic.primitives import make_sym_vector
return MultiVector(
make_sym_vector(expr.name, self.ambient_dim,
var_factory=type(expr)))
def map_nabla(self, expr):
from pytools.obj_array import make_obj_array
return MultiVector(make_obj_array(
[prim.NablaComponent(axis, expr.nabla_id)
for axis in range(self.ambient_dim)]))
def map_derivative_source(self, expr):
rec_op = self.rec(expr.operand)
if isinstance(rec_op, MultiVector):
from pymbolic.geometric_algebra.primitives import DerivativeSource
return rec_op.map(
lambda coeff: DerivativeSource(coeff, expr.nabla_id))
else:
return super().map_derivative_source(expr)
# }}}
# {{{ derivative binder
class DerivativeSourceAndNablaComponentCollector(CachedMapper, Collector):
def __init__(self) -> None:
Collector.__init__(self)
CachedMapper.__init__(self)
def map_nabla(self, expr):
raise RuntimeError("DerivativeOccurrenceMapper must be invoked after "
"Dimensionalizer--Nabla found, not allowed")
def map_nabla_component(self, expr):
return {expr}
def map_derivative_source(self, expr):
return {expr} | self.rec(expr.operand)
class NablaComponentToUnitVector(EvaluationMapper):
def __init__(self, nabla_id, ambient_axis):
self.nabla_id = nabla_id
self.ambient_axis = ambient_axis
def map_variable(self, expr):
return expr
def map_nabla_component(self, expr):
if expr.nabla_id == self.nabla_id:
if expr.ambient_axis == self.ambient_axis:
return 1
else:
return 0
else:
return EvaluationMapper.map_nabla_component(self, expr)
class DerivativeSourceFinder(EvaluationMapper):
"""Recurses down until it finds the
:class:`pymbolic.geometric_algebra.primitives.DerivativeSource`
with the right *nabla_id*, then calls :method:`DerivativeBinder.take_derivative`
on the source's argument.
"""
def __init__(self, nabla_id, binder, ambient_axis):
self.nabla_id = nabla_id
self.binder = binder
self.ambient_axis = ambient_axis
def map_derivative_source(self, expr):
if expr.nabla_id == self.nabla_id:
return self.binder.take_derivative(self.ambient_axis, expr.operand)
else:
return EvaluationMapper.map_derivative_source(self, expr)
class DerivativeBinder(IdentityMapper):
derivative_source_and_nabla_component_collector = \
DerivativeSourceAndNablaComponentCollector
nabla_component_to_unit_vector = NablaComponentToUnitVector
derivative_source_finder = DerivativeSourceFinder
def __init__(self, restrict_to_id=None):
self.derivative_collector = \
self.derivative_source_and_nabla_component_collector()
self.restrict_to_id = restrict_to_id
def take_derivative(self, ambient_axis, expr):
raise NotImplementedError
def map_product(self, expr):
# {{{ gather NablaComponents and DerivativeSources
d_source_nabla_ids_per_child = []
# id to set((child index, axis), ...)
nabla_finder = {}
has_d_source_nablas = False
for child_idx, child in enumerate(expr.children):
d_or_ns = self.derivative_collector(child)
if not d_or_ns:
d_source_nabla_ids_per_child.append(set())
continue
nabla_component_ids = set()
derivative_source_ids = set()
nablas = []
for d_or_n in d_or_ns:
if isinstance(d_or_n, prim.NablaComponent):
nabla_component_ids.add(d_or_n.nabla_id)
nablas.append(d_or_n)
elif isinstance(d_or_n, prim.DerivativeSource):
derivative_source_ids.add(d_or_n.nabla_id)
else:
raise RuntimeError("unexpected result from "
"DerivativeSourceAndNablaComponentCollector")
d_source_nabla_ids_per_child.append(derivative_source_ids)
if derivative_source_ids:
has_d_source_nablas = True
for ncomp in nablas:
nabla_finder.setdefault(
ncomp.nabla_id, set()).add((child_idx, ncomp.ambient_axis))
if nabla_finder and not any(d_source_nabla_ids_per_child):
raise ValueError(f"no derivative source found to resolve in '{expr}'"
" -- did you forget to wrap the term that should have its "
"derivative taken in 'Derivative()(term)'?")
if not has_d_source_nablas:
rec_children = [self.rec(child) for child in expr.children]
if all(rec_child is child
for rec_child, child in zip(
rec_children, expr.children, strict=True)):
return expr
return type(expr)(tuple(rec_children))
# }}}
# a list of lists, the outer level presenting a sum, the inner a product
result = [list(expr.children)]
for child_idx, (d_source_nabla_ids, _child) in enumerate(
zip(d_source_nabla_ids_per_child, expr.children, strict=True)):
if not d_source_nabla_ids:
continue
if len(d_source_nabla_ids) > 1:
raise NotImplementedError("more than one DerivativeSource per "
"child in a product")
nabla_id, = d_source_nabla_ids
try:
nablas = nabla_finder[nabla_id]
except KeyError:
continue
if self.restrict_to_id is not None and nabla_id != self.restrict_to_id:
continue
n_axes = max(axis for _, axis in nablas) + 1
new_result = []
for prod_term_list in result:
for axis in range(n_axes):
new_ptl = prod_term_list[:]
dsfinder = self.derivative_source_finder(nabla_id, self, axis)
new_ptl[child_idx] = dsfinder(new_ptl[child_idx])
for nabla_child_index, _ in nablas:
new_ptl[nabla_child_index] = \
self.nabla_component_to_unit_vector(nabla_id, axis)(
new_ptl[nabla_child_index])
new_result.append(new_ptl)
result = new_result
from pymbolic.primitives import flattened_product, flattened_sum
return flattened_sum([
flattened_product([
self.rec(prod_term) for prod_term in prod_term_list
])
for prod_term_list in result
])
map_bitwise_xor = map_product
map_bitwise_or = map_product
map_left_shift = map_product
map_right_shift = map_product
def map_derivative_source(self, expr):
rec_operand = self.rec(expr.operand)
nablas = []
for d_or_n in self.derivative_collector(rec_operand):
if isinstance(d_or_n, prim.NablaComponent):
nablas.append(d_or_n)
elif isinstance(d_or_n, prim.DerivativeSource):
pass
else:
raise RuntimeError("unexpected result from "
"DerivativeSourceAndNablaComponentCollector")
n_axes = max(n.ambient_axis for n in nablas) + 1
assert n_axes
from pymbolic.primitives import flattened_sum
return flattened_sum([
self.take_derivative(
axis,
self.nabla_component_to_unit_vector(expr.nabla_id, axis)
(rec_operand))
for axis in range(n_axes)
])
# }}}
class ConstantFoldingMapper(IdentityMapper, ConstantFoldingMapperBase):
pass
# vim: foldmethod=marker
from __future__ import annotations
__copyright__ = "Copyright (C) 2014 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
# This is experimental, undocumented, and could go away any second.
# Consider yourself warned.
from typing import TYPE_CHECKING, ClassVar
from pymbolic.primitives import ExpressionNode, Variable, expr_dataclass
if TYPE_CHECKING:
from collections.abc import Hashable
from pymbolic.typing import Expression
class MultiVectorVariable(Variable):
mapper_method = "map_multivector_variable"
# {{{ geometric calculus
class _GeometricCalculusExpression(ExpressionNode):
def stringifier(self):
from pymbolic.geometric_algebra.mapper import StringifyMapper
return StringifyMapper
@expr_dataclass()
class NablaComponent(_GeometricCalculusExpression):
ambient_axis: int
nabla_id: Hashable
@expr_dataclass()
class Nabla(_GeometricCalculusExpression):
nabla_id: Hashable
@expr_dataclass()
class DerivativeSource(_GeometricCalculusExpression):
operand: Expression
nabla_id: Hashable
class Derivative:
"""This mechanism cannot be used to take more than one derivative at a time.
.. autoproperty:: nabla
.. automethod:: __call__
.. automethod:: dnabla
.. automethod:: resolve
"""
_next_id: ClassVar[list[int]] = [0]
def __init__(self):
self.my_id = f"id{self._next_id[0]}"
self._next_id[0] += 1
@property
def nabla(self):
return Nabla(self.my_id)
def dnabla(self, ambient_dim):
from pytools.obj_array import make_obj_array
from pymbolic.geometric_algebra import MultiVector
return MultiVector(make_obj_array(
[NablaComponent(axis, self.my_id)
for axis in range(ambient_dim)]))
def __call__(self, operand):
from pymbolic.geometric_algebra import MultiVector
if isinstance(operand, MultiVector):
return operand.map(
lambda coeff: DerivativeSource(coeff, self.my_id))
else:
return DerivativeSource(operand, self.my_id)
@staticmethod
def resolve(expr):
# This method will need to be overridden by codes using this
# infrastructure to use the appropriate subclass of DerivativeBinder.
from pymbolic.geometric_algebra.mapper import DerivativeBinder
return DerivativeBinder()(expr)
# }}}
# vim: foldmethod=marker
"""Imperative program representation"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Matt Wala, Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
"""Fusion and other user-facing code transforms"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Matt Wala, Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
def get_all_used_insn_ids(insn_stream):
return frozenset(insn.id for insn in insn_stream)
def get_all_used_identifiers(insn_stream):
result = set()
for insn in insn_stream:
result |= insn.get_read_variables()
result |= insn.get_written_variables()
return result
"""Instruction types"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Matt Wala, Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from warnings import warn
warn("pymbolic.imperative.instruction was imported. This has been renamed "
"to pymbolic.imperative.statement", DeprecationWarning, stacklevel=1)
from pymbolic.imperative.statement import ( # noqa: F401
Assignment,
ConditionalAssignment,
ConditionalInstruction,
Instruction,
Nop,
)
"""Instruction types"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Matt Wala, Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from sys import intern
from pytools import RecordWithoutPickling
from pymbolic.typing import not_none
# {{{ statemetn classes
class Statement(RecordWithoutPickling):
"""
.. attribute:: depends_on
A :class:`frozenset` of instruction ids that are reuqired to be
executed within this execution context before this instruction can be
executed.
.. attribute:: id
A string, a unique identifier for this instruction.
.. automethod:: get_written_variables
.. automethod:: get_read_variables
"""
def __init__(self, **kwargs):
id = kwargs.pop("id", None)
if id is not None:
id = intern(id)
depends_on = frozenset(kwargs.pop("depends_on", []))
super().__init__(
id=id,
depends_on=depends_on,
**kwargs)
def get_written_variables(self):
"""Returns a :class:`frozenset` of variables being written by this
instruction.
"""
return frozenset()
def get_read_variables(self):
"""Returns a :class:`frozenset` of variables being read by this
instruction.
"""
return frozenset()
def map_expressions(self, mapper, include_lhs=True):
"""Returns a new copy of *self* with all expressions
replaced by ``mapepr(expr)`` for every
:class:`pymbolic.primitives.Expression`
contained in *self*.
"""
return self
def get_dependency_mapper(self, include_calls="descend_args"):
from pymbolic.mapper.dependency import DependencyMapper
return DependencyMapper(
include_subscripts=False,
include_lookups=False,
include_calls=include_calls)
# }}}
# {{{ statement with condition
class ConditionalStatement(Statement):
__doc__ = not_none(Statement.__doc__) + """
.. attribute:: condition
The instruction condition as a :mod:`pymbolic` expression (`True` if the
instruction is unconditionally executed)
"""
def __init__(self, **kwargs):
condition = kwargs.pop("condition", True)
super().__init__(
condition=condition,
**kwargs)
def _condition_printing_suffix(self):
if self.condition is True:
return ""
return " if " + str(self.condition)
def __str__(self):
return (super().__str__()
+ self._condition_printing_suffix())
def get_read_variables(self):
dep_mapper = self.get_dependency_mapper()
return (
super().get_read_variables()
| frozenset(
dep.name for dep in dep_mapper(self.condition)))
# }}}
# {{{ assignment
class Assignment(Statement):
"""
.. attribute:: lhs
.. attribute:: rhs
"""
def __init__(self, lhs, rhs, **kwargs):
super().__init__(
lhs=lhs,
rhs=rhs,
**kwargs)
def get_written_variables(self):
from pymbolic.primitives import Subscript, Variable
if isinstance(self.lhs, Variable):
return frozenset([self.lhs.name])
elif isinstance(self.lhs, Subscript):
assert isinstance(self.lhs.aggregate, Variable)
return frozenset([self.lhs.aggregate.name])
else:
raise TypeError("unexpected type of LHS")
def get_read_variables(self):
result = super().get_read_variables()
get_deps = self.get_dependency_mapper()
def get_vars(expr):
return frozenset(dep.name for dep in get_deps(self.rhs))
result = get_vars(self.rhs) | get_vars(self.lhs)
return result
def map_expressions(self, mapper, include_lhs=True):
return (super()
.map_expressions(mapper, include_lhs=include_lhs)
.copy(
lhs=mapper(self.lhs) if include_lhs else self.lhs,
rhs=mapper(self.rhs)))
def __str__(self):
result = "{assignee} <- {expr}".format(
assignee=str(self.lhs),
expr=str(self.rhs),)
return result
# }}}
# {{{ conditional assignment
class ConditionalAssignment(ConditionalStatement, Assignment):
def map_expressions(self, mapper, include_lhs=True):
return (super()
.map_expressions(mapper, include_lhs=include_lhs)
.copy(condition=mapper(self.condition)))
# }}}
# {{{ nop
class Nop(Statement):
def __str__(self):
return "nop"
# }}}
Instruction = Statement
ConditionalInstruction = ConditionalStatement
# vim: foldmethod=marker
"""Imperative program representation: transformations"""
from __future__ import annotations
__copyright__ = "Copyright (C) 2015 Matt Wala, Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
# {{{ fuse statement streams
def fuse_statement_streams_with_unique_ids(statements_a, statements_b):
new_statements = list(statements_a)
from pytools import UniqueNameGenerator
stmt_id_gen = UniqueNameGenerator(
{stmta.id for stmta in new_statements})
b_unique_statements = []
old_b_id_to_new_b_id = {}
for stmtb in statements_b:
old_id = stmtb.id
new_id = stmt_id_gen(old_id)
old_b_id_to_new_b_id[old_id] = new_id
b_unique_statements.append(
stmtb.copy(id=new_id))
for stmtb in b_unique_statements:
new_statements.append(
stmtb.copy(
depends_on=frozenset(
old_b_id_to_new_b_id[dep_id]
for dep_id in stmtb.depends_on)))
return new_statements, old_b_id_to_new_b_id
def fuse_instruction_streams_with_unique_ids(insns_a, insns_b):
from warnings import warn
warn("fuse_instruction_streams_with_unique_ids has been renamed to "
"fuse_statement_streams_with_unique_ids", DeprecationWarning,
stacklevel=2)
return fuse_statement_streams_with_unique_ids(insns_a, insns_b)
# }}}
# {{{ disambiguate_identifiers
def disambiguate_identifiers(statements_a, statements_b,
should_disambiguate_name=None):
if should_disambiguate_name is None:
def should_disambiguate_name(name): # pylint:disable=function-redefined
return True
from pymbolic.imperative.analysis import get_all_used_identifiers
id_a = get_all_used_identifiers(statements_a)
id_b = get_all_used_identifiers(statements_b)
from pytools import UniqueNameGenerator
vng = UniqueNameGenerator(id_a | id_b)
from pymbolic import var
subst_b = {}
for clash in id_a & id_b:
if should_disambiguate_name(clash):
unclash = vng(clash)
subst_b[clash] = var(unclash)
from pymbolic.mapper.substitutor import SubstitutionMapper, make_subst_func
subst_map = SubstitutionMapper(make_subst_func(subst_b))
statements_b = [
stmt.map_expressions(subst_map) for stmt in statements_b]
return statements_b, subst_b
# }}}
# {{{ disambiguate_and_fuse
def disambiguate_and_fuse(statements_a, statements_b,
should_disambiguate_name=None):
statements_b, subst_b = disambiguate_identifiers(
statements_a, statements_b,
should_disambiguate_name)
fused, old_b_id_to_new_b_id = \
fuse_statement_streams_with_unique_ids(
statements_a, statements_b)
return fused, subst_b, old_b_id_to_new_b_id
# }}}
# vim: foldmethod=marker
This diff is collapsed.
This diff is collapsed.