Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • mattwala/sumpy
  • isuruf/sumpy
  • xywei/sumpy
  • inducer/sumpy
  • fikl2/sumpy
  • ben_sepanski/sumpy
6 results
Show changes
Showing
with 6780 additions and 2017 deletions
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from setuptools import setup
# Namespace that receives the names defined by sumpy/version.py
# (VERSION_TEXT is read from it further down in the setup() call).
ver_dic = {}

# Read the version module as text instead of importing the package, so
# that the package itself does not have to be importable at setup time.
with open("sumpy/version.py") as version_file:
    version_file_contents = version_file.read()

# NOTE(review): presumably tells version.py that it is being executed from
# setup.py rather than imported normally -- confirm against version.py.
os.environ["AKPYTHON_EXEC_FROM_WITHIN_WITHIN_SETUP_PY"] = "1"

version_code = compile(version_file_contents, "sumpy/version.py", "exec")
exec(version_code, ver_dic)
# {{{ capture git revision at install time
# authoritative version in pytools/__init__.py
def find_git_revision(tree_root):
    """Return the ``HEAD`` commit hash of the git checkout at *tree_root*.

    :arg tree_root: path of the directory that is expected to contain a
        ``.git`` entry.
    :returns: the standard output of ``git rev-parse HEAD`` with trailing
        whitespace removed, or *None* if *tree_root* has no ``.git`` entry,
        if the ``git`` executable cannot be started, or if ``git`` exits
        with a non-zero status. The latter two cases also emit a warning.
    """
    # Keep this routine self-contained so that it can be copy-pasted into
    # setup.py.

    from os.path import abspath, exists, join
    tree_root = abspath(tree_root)

    if not exists(join(tree_root, ".git")):
        return None

    from subprocess import PIPE, Popen
    from warnings import warn

    try:
        # stderr gets its own pipe so that diagnostics printed by git can
        # never end up inside the returned revision string.
        p = Popen(["git", "rev-parse", "HEAD"], shell=False,
                  stdin=PIPE, stdout=PIPE, stderr=PIPE, close_fds=True,
                  cwd=tree_root)
    except OSError:
        # git is not installed or not executable: treat this like any other
        # failure to determine the revision instead of aborting the caller.
        warn("unable to find git revision")
        return None

    (git_rev, _) = p.communicate()

    retcode = p.returncode
    assert retcode is not None
    if retcode != 0:
        warn("unable to find git revision")
        return None

    import sys
    if sys.version_info >= (3,):
        # communicate() hands back bytes on Python 3
        git_rev = git_rev.decode()

    return git_rev.rstrip()
def write_git_revision(package_name):
    """Write the current git revision to ``<package_name>/_git_rev.py``.

    The revision is looked up with :func:`find_git_revision` in the directory
    containing this file and stored as a ``GIT_REVISION = ...`` assignment
    (``GIT_REVISION = None`` if no revision could be determined).

    :arg package_name: name of the package directory, relative to the
        directory of this file, that receives ``_git_rev.py``.
    """
    import os.path

    here = os.path.dirname(__file__)
    revision = find_git_revision(here)

    target = os.path.join(here, package_name, "_git_rev.py")
    with open(target, "w") as outf:
        outf.write("GIT_REVISION = %r\n" % (revision,))
# Executed every time this script runs: (re)writes sumpy/_git_rev.py next to
# this file before setup() is called.
write_git_revision("sumpy")
# }}}
# Trove classifiers passed to setup() below.
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
    "Intended Audience :: Other Audience",
    "Intended Audience :: Science/Research",
    "License :: OSI Approved :: MIT License",
    "Natural Language :: English",
    "Programming Language :: Python",
    "Topic :: Scientific/Engineering",
    "Topic :: Scientific/Engineering :: Information Analysis",
    "Topic :: Scientific/Engineering :: Mathematics",
    "Topic :: Scientific/Engineering :: Visualization",
    "Topic :: Software Development :: Libraries",
    "Topic :: Utilities",
]

# Requirement specifiers passed to setup() as install_requires.
install_requires = [
    "pytools>=2018.2",
    "loo.py>=2017.2",
    "boxtree>=2018.1",
    "pytest>=2.3",
    "six",

    # If this causes issues, see:
    # https://code.google.com/p/sympy/issues/detail?id=3874
    "sympy>=0.7.2",
]

setup(
    name="sumpy",
    # VERSION_TEXT comes from executing sumpy/version.py into ver_dic
    version=ver_dic["VERSION_TEXT"],
    description="Fast summation in Python",
    long_description="\nCode-generating FMM etc.\n",
    classifiers=classifiers,
    author="Andreas Kloeckner",
    author_email="inform@tiker.net",
    license="MIT",
    packages=["sumpy", "sumpy.expansion"],
    install_requires=install_requires,
)
from __future__ import division, absolute_import from __future__ import annotations
__copyright__ = "Copyright (C) 2013 Andreas Kloeckner" __copyright__ = "Copyright (C) 2013 Andreas Kloeckner"
...@@ -23,21 +24,45 @@ THE SOFTWARE. ...@@ -23,21 +24,45 @@ THE SOFTWARE.
""" """
import os import os
from sumpy.p2p import P2P, P2PFromCSR from collections.abc import Hashable
from sumpy.p2e import P2EFromSingleBox, P2EFromCSR
from sumpy.e2p import E2PFromSingleBox, E2PFromCSR import loopy as lp
from sumpy.e2e import E2EFromCSR, E2EFromChildren, E2EFromParent
from sumpy.version import VERSION_TEXT
from pytools.persistent_dict import WriteOncePersistentDict from pytools.persistent_dict import WriteOncePersistentDict
__all__ = [ from sumpy.e2e import (
"P2P", "P2PFromCSR", E2EFromChildren,
"P2EFromSingleBox", "P2EFromCSR", E2EFromCSR,
"E2PFromSingleBox", "E2PFromCSR", E2EFromParent,
"E2EFromCSR", "E2EFromChildren", "E2EFromParent"] M2LGenerateTranslationClassesDependentData,
M2LPostprocessLocal,
M2LPreprocessMultipole,
M2LUsingTranslationClassesDependentData,
)
from sumpy.e2p import E2PFromCSR, E2PFromSingleBox
from sumpy.p2e import P2EFromCSR, P2EFromSingleBox
from sumpy.p2p import P2P, P2PFromCSR
from sumpy.version import VERSION_TEXT
code_cache = WriteOncePersistentDict("sumpy-code-cache-v6-"+VERSION_TEXT) __all__ = [
"P2P",
"E2EFromCSR",
"E2EFromChildren",
"E2EFromParent",
"E2PFromCSR",
"E2PFromSingleBox",
"M2LGenerateTranslationClassesDependentData",
"M2LPostprocessLocal",
"M2LPreprocessMultipole",
"M2LUsingTranslationClassesDependentData",
"P2EFromCSR",
"P2EFromSingleBox",
"P2PFromCSR",
]
code_cache: WriteOncePersistentDict[Hashable, lp.TranslationUnit] = \
WriteOncePersistentDict("sumpy-code-cache-v6-"+VERSION_TEXT, safe_sync=False)
# {{{ optimization control # {{{ optimization control
...@@ -63,31 +88,37 @@ CACHING_ENABLED = ( ...@@ -63,31 +88,37 @@ CACHING_ENABLED = (
"SUMPY_NO_CACHE" not in os.environ "SUMPY_NO_CACHE" not in os.environ
and "CG_NO_CACHE" not in os.environ) and "CG_NO_CACHE" not in os.environ)
NO_CACHE_KERNELS = tuple(os.environ.get("SUMPY_NO_CACHE_KERNELS",
"").split(","))
def set_caching_enabled(flag): def set_caching_enabled(flag, no_cache_kernels=()):
"""Set whether :mod:`loopy` is allowed to use disk caching for its various """Set whether :mod:`loopy` is allowed to use disk caching for its various
code generation stages. code generation stages.
""" """
global CACHING_ENABLED global CACHING_ENABLED, NO_CACHE_KERNELS
NO_CACHE_KERNELS = no_cache_kernels
CACHING_ENABLED = flag CACHING_ENABLED = flag
class CacheMode(object): class CacheMode:
"""A context manager for setting whether :mod:`sumpy` is allowed to use """A context manager for setting whether :mod:`sumpy` is allowed to use
disk caches. disk caches.
""" """
def __init__(self, new_flag): def __init__(self, new_flag, new_no_cache_kernels=()):
self.new_flag = new_flag self.new_flag = new_flag
self.new_no_cache_kernels = new_no_cache_kernels
def __enter__(self): def __enter__(self):
global CACHING_ENABLED global CACHING_ENABLED, NO_CACHE_KERNELS
self.previous_mode = CACHING_ENABLED self.previous_flag = CACHING_ENABLED
CACHING_ENABLED = self.new_flag self.previous_kernels = NO_CACHE_KERNELS
set_caching_enabled(self.new_flag, self.new_no_cache_kernels)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
global CACHING_ENABLED set_caching_enabled(self.previous_flag, self.previous_kernels)
CACHING_ENABLED = self.previous_mode del self.previous_flag
del self.previous_mode del self.previous_kernels
# }}} # }}}
from __future__ import annotations
__copyright__ = "Copyright (C) 2022 Alexandru Fikl"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from boxtree.array_context import PyOpenCLArrayContext as PyOpenCLArrayContextBase
from arraycontext.pytest import (
_PytestPyOpenCLArrayContextFactoryWithClass,
register_pytest_array_context_factory,
)
__doc__ = """
Array Context
-------------
.. autoclass:: PyOpenCLArrayContext
"""
# {{{ PyOpenCLArrayContext
class PyOpenCLArrayContext(PyOpenCLArrayContextBase):
    """Array context that validates loopy programs before transforming them.

    .. automethod:: transform_loopy_program
    """

    def transform_loopy_program(self, t_unit):
        """Check the options of *t_unit*'s default entrypoint, then defer to
        the base class implementation.

        :raises ValueError: if the default entrypoint does not have both the
            ``return_dict`` and the ``no_numpy`` option set.
        """
        entrypoint_options = t_unit.default_entrypoint.options

        if not entrypoint_options.return_dict or not entrypoint_options.no_numpy:
            raise ValueError(
                "Loopy kernel passed to call_loopy must "
                "have return_dict and no_numpy options set. "
                "Did you use arraycontext.make_loopy_program "
                "to create this kernel?")

        return super().transform_loopy_program(t_unit)
# }}}
# {{{ pytest
def _acf():
    """Return a :class:`PyOpenCLArrayContext` built on a command queue of a
    newly created OpenCL context (``pyopencl.create_some_context``).
    """
    import pyopencl as cl

    return PyOpenCLArrayContext(cl.CommandQueue(cl.create_some_context()))
class PytestPyOpenCLArrayContextFactory(
        _PytestPyOpenCLArrayContextFactoryWithClass):
    """Pytest array context factory whose ``actx_class`` is
    :class:`PyOpenCLArrayContext`.
    """

    actx_class = PyOpenCLArrayContext

    def __call__(self):
        """Clear sympy's cache, then build the array context via the base
        class.
        """
        # NOTE: prevent any cache explosions during testing!
        import sympy.core.cache as sympy_cache
        sympy_cache.clear_cache()

        return super().__call__()
# Module-level side effect: makes the factory above available to
# arraycontext's pytest integration under the name "sumpy.pyopencl".
register_pytest_array_context_factory(
"sumpy.pyopencl",
PytestPyOpenCLArrayContextFactory)
# }}}
from __future__ import division from __future__ import annotations
from __future__ import absolute_import
import six
from six.moves import zip
__copyright__ = "Copyright (C) 2012 Andreas Kloeckner" __copyright__ = "Copyright (C) 2012 Andreas Kloeckner"
...@@ -26,9 +24,11 @@ THE SOFTWARE. ...@@ -26,9 +24,11 @@ THE SOFTWARE.
""" """
import logging
import sumpy.symbolic as sym import sumpy.symbolic as sym
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__doc__ = """ __doc__ = """
...@@ -41,7 +41,7 @@ Manipulating batches of assignments ...@@ -41,7 +41,7 @@ Manipulating batches of assignments
""" """
class _SymbolGenerator(object): class _SymbolGenerator:
def __init__(self, taken_symbols): def __init__(self, taken_symbols):
self.taken_symbols = taken_symbols self.taken_symbols = taken_symbols
...@@ -83,7 +83,7 @@ class _SymbolGenerator(object): ...@@ -83,7 +83,7 @@ class _SymbolGenerator(object):
# {{{ collection of assignments # {{{ collection of assignments
class SymbolicAssignmentCollection(object): class SymbolicAssignmentCollection:
"""Represents a collection of assignments:: """Represents a collection of assignments::
a = 5*x a = 5*x
...@@ -109,16 +109,15 @@ class SymbolicAssignmentCollection(object): ...@@ -109,16 +109,15 @@ class SymbolicAssignmentCollection(object):
assignments = {} assignments = {}
self.assignments = assignments self.assignments = assignments
self.reversed_assignments = {v: k for (k, v) in assignments.items()}
self.symbol_generator = _SymbolGenerator(self.assignments) self.symbol_generator = _SymbolGenerator(self.assignments)
self.all_dependencies_cache = {} self.all_dependencies_cache = {}
self.user_symbols = set()
def __str__(self): def __str__(self):
return "\n".join( return "\n".join(
"%s <- %s" % (name, expr) f"{name} <- {expr}"
for name, expr in six.iteritems(self.assignments)) for name, expr in self.assignments.items())
def get_all_dependencies(self, var_name): def get_all_dependencies(self, var_name):
"""Including recursive dependencies.""" """Including recursive dependencies."""
...@@ -144,8 +143,9 @@ class SymbolicAssignmentCollection(object): ...@@ -144,8 +143,9 @@ class SymbolicAssignmentCollection(object):
self.all_dependencies_cache[var_name] = result self.all_dependencies_cache[var_name] = result
return result return result
def add_assignment(self, name, expr, root_name=None, wrt_set=None): def add_assignment(self, name, expr, root_name=None, wrt_set=None,
assert isinstance(name, six.string_types) retain_name=True):
assert isinstance(name, str)
assert name not in self.assignments assert name not in self.assignments
if wrt_set is None: if wrt_set is None:
...@@ -153,7 +153,15 @@ class SymbolicAssignmentCollection(object): ...@@ -153,7 +153,15 @@ class SymbolicAssignmentCollection(object):
if root_name is None: if root_name is None:
root_name = name root_name = name
self.assignments[name] = sym.sympify(expr) new_expr = sym.sympify(expr)
if not retain_name and new_expr in self.reversed_assignments:
return self.reversed_assignments[new_expr]
self.assignments[name] = new_expr
self.reversed_assignments[new_expr] = name
return name
def assign_unique(self, name_base, expr): def assign_unique(self, name_base, expr):
"""Assign *expr* to a new variable whose name is based on *name_base*. """Assign *expr* to a new variable whose name is based on *name_base*.
...@@ -161,17 +169,26 @@ class SymbolicAssignmentCollection(object): ...@@ -161,17 +169,26 @@ class SymbolicAssignmentCollection(object):
""" """
new_name = self.symbol_generator(name_base).name new_name = self.symbol_generator(name_base).name
self.add_assignment(new_name, expr) return self.add_assignment(new_name, expr)
self.user_symbols.add(new_name)
return new_name def assign_temp(self, name_base, expr):
"""If *expr* is mapped to a existing variable, then return the existing
variable or assign *expr* to a new variable whose name is based on
*name_base*. Return the variable name *expr* is mapped to in either case.
"""
new_name = self.symbol_generator(name_base).name
return self.add_assignment(new_name, expr, retain_name=False)
def run_global_cse(self, extra_exprs=None):
if extra_exprs is None:
extra_exprs = []
def run_global_cse(self, extra_exprs=[]):
import time import time
start_time = time.time() start_time = time.time()
logger.info("common subexpression elimination: start") logger.info("common subexpression elimination: start")
assign_names = sorted(self.assignments) assign_names = list(self.assignments.keys())
assign_exprs = [self.assignments[name] for name in assign_names] assign_exprs = [self.assignments[name] for name in assign_names]
# Options here: # Options here:
...@@ -179,7 +196,7 @@ class SymbolicAssignmentCollection(object): ...@@ -179,7 +196,7 @@ class SymbolicAssignmentCollection(object):
# Uses maxima to verify. # Uses maxima to verify.
# - sym.cse: The sympy thing. # - sym.cse: The sympy thing.
# - sumpy.cse.cse: Based on sympy, designed to go faster. # - sumpy.cse.cse: Based on sympy, designed to go faster.
#from sumpy.symbolic import checked_cse # from sumpy.symbolic import checked_cse
from sumpy.cse import cse from sumpy.cse import cse
new_assignments, new_exprs = cse(assign_exprs + extra_exprs, new_assignments, new_exprs = cse(assign_exprs + extra_exprs,
...@@ -188,15 +205,24 @@ class SymbolicAssignmentCollection(object): ...@@ -188,15 +205,24 @@ class SymbolicAssignmentCollection(object):
new_assign_exprs = new_exprs[:len(assign_exprs)] new_assign_exprs = new_exprs[:len(assign_exprs)]
new_extra_exprs = new_exprs[len(assign_exprs):] new_extra_exprs = new_exprs[len(assign_exprs):]
for name, new_expr in zip(assign_names, new_assign_exprs): for name, new_expr in zip(assign_names, new_assign_exprs, strict=True):
self.assignments[name] = new_expr self.assignments[name] = new_expr
for name, value in new_assignments: for name, value in new_assignments:
assert isinstance(name, sym.Symbol) assert isinstance(name, sym.Symbol)
self.add_assignment(name.name, value) self.add_assignment(name.name, value)
logger.info("common subexpression elimination: done after {dur:.2f} s" for name, new_expr in zip(assign_names, new_assign_exprs, strict=True):
.format(dur=time.time() - start_time)) # We want the assignment collection to be ordered correctly
# to make it easier for loopy to schedule.
# Deleting the original assignments and adding them again
# makes them occur after the CSE'd expression preserving
# the order of operations.
del self.assignments[name]
self.assignments[name] = new_expr
logger.info("common subexpression elimination: done after %.2f s",
time.time() - start_time)
return new_extra_exprs return new_extra_exprs
# }}} # }}}
......
This diff is collapsed.
from __future__ import print_function, division from __future__ import annotations
__copyright__ = """ __copyright__ = """
Copyright (C) 2017 Matt Wala Copyright (C) 2017 Matt Wala
...@@ -66,11 +67,17 @@ DAMAGE. ...@@ -66,11 +67,17 @@ DAMAGE.
# }}} # }}}
from sumpy.symbolic import (
Basic, Mul, Add, Pow, Symbol, _coeff_isneg, Derivative, Subs)
from sympy.core.compatibility import iterable
from sympy.utilities.iterables import numbered_symbols from sympy.utilities.iterables import numbered_symbols
from sumpy.symbolic import Add, Basic, Derivative, Mul, Pow, Subs, Symbol, _coeff_isneg
try:
from sympy.utilities.iterables import iterable
except ImportError:
# NOTE: deprecated and moved in sympy 1.10
from sympy.core.compatibility import iterable
__doc__ = """ __doc__ = """
...@@ -98,7 +105,7 @@ def preprocess_for_cse(expr, optimizations): ...@@ -98,7 +105,7 @@ def preprocess_for_cse(expr, optimizations):
:return: The transformed expression. :return: The transformed expression.
""" """
for pre, post in optimizations: for pre, _post in optimizations:
if pre is not None: if pre is not None:
expr = pre(expr) expr = pre(expr)
return expr return expr
...@@ -117,7 +124,7 @@ def postprocess_for_cse(expr, optimizations): ...@@ -117,7 +124,7 @@ def postprocess_for_cse(expr, optimizations):
:return: The transformed expression. :return: The transformed expression.
""" """
for pre, post in reversed(optimizations): for _pre, post in reversed(optimizations):
if post is not None: if post is not None:
expr = post(expr) expr = post(expr)
return expr return expr
...@@ -127,7 +134,7 @@ def postprocess_for_cse(expr, optimizations): ...@@ -127,7 +134,7 @@ def postprocess_for_cse(expr, optimizations):
# {{{ opt cse # {{{ opt cse
class FuncArgTracker(object): class FuncArgTracker:
""" """
A class which manages a mapping from functions to arguments and an inverse A class which manages a mapping from functions to arguments and an inverse
mapping from arguments to functions. mapping from arguments to functions.
...@@ -215,7 +222,7 @@ class FuncArgTracker(object): ...@@ -215,7 +222,7 @@ class FuncArgTracker(object):
if func_i in larger_funcs_container: if func_i in larger_funcs_container:
count_map[func_i] += 1 count_map[func_i] += 1
return dict((k, v) for k, v in count_map.items() if v >= 2) return {k: v for k, v in count_map.items() if v >= 2}
def get_subset_candidates(self, argset, restrict_to_funcset=None): def get_subset_candidates(self, argset, restrict_to_funcset=None):
""" """
...@@ -225,8 +232,7 @@ class FuncArgTracker(object): ...@@ -225,8 +232,7 @@ class FuncArgTracker(object):
""" """
iarg = iter(argset) iarg = iter(argset)
indices = set( indices = set(self.arg_to_funcset[next(iarg)])
fi for fi in self.arg_to_funcset[next(iarg)])
if restrict_to_funcset is not None: if restrict_to_funcset is not None:
indices &= restrict_to_funcset indices &= restrict_to_funcset
...@@ -252,7 +258,7 @@ class FuncArgTracker(object): ...@@ -252,7 +258,7 @@ class FuncArgTracker(object):
self.func_to_argset[func_i].update(new_args) self.func_to_argset[func_i].update(new_args)
class Unevaluated(object): class Unevaluated:
def __init__(self, func, args): def __init__(self, func, args):
self.func = func self.func = func
...@@ -325,7 +331,7 @@ def match_common_args(func_class, funcs, opt_subs): ...@@ -325,7 +331,7 @@ def match_common_args(func_class, funcs, opt_subs):
com_func = Unevaluated( com_func = Unevaluated(
func_class, arg_tracker.get_args_in_value_order(com_args)) func_class, arg_tracker.get_args_in_value_order(com_args))
com_func_number = arg_tracker.get_or_add_value_number(com_func) com_func_number = arg_tracker.get_or_add_value_number(com_func)
arg_tracker.update_func_argset(i, diff_i | set([com_func_number])) arg_tracker.update_func_argset(i, diff_i | {com_func_number})
changed.add(i) changed.add(i)
else: else:
# Treat the whole expression as a CSE. # Treat the whole expression as a CSE.
...@@ -340,13 +346,13 @@ def match_common_args(func_class, funcs, opt_subs): ...@@ -340,13 +346,13 @@ def match_common_args(func_class, funcs, opt_subs):
com_func_number = arg_tracker.get_or_add_value_number(funcs[i]) com_func_number = arg_tracker.get_or_add_value_number(funcs[i])
diff_j = arg_tracker.func_to_argset[j].difference(com_args) diff_j = arg_tracker.func_to_argset[j].difference(com_args)
arg_tracker.update_func_argset(j, diff_j | set([com_func_number])) arg_tracker.update_func_argset(j, diff_j | {com_func_number})
changed.add(j) changed.add(j)
for k in arg_tracker.get_subset_candidates( for k in arg_tracker.get_subset_candidates(
com_args, common_arg_candidates): com_args, common_arg_candidates):
diff_k = arg_tracker.func_to_argset[k].difference(com_args) diff_k = arg_tracker.func_to_argset[k].difference(com_args)
arg_tracker.update_func_argset(k, diff_k | set([com_func_number])) arg_tracker.update_func_argset(k, diff_k | {com_func_number})
changed.add(k) changed.add(k)
if i in changed: if i in changed:
...@@ -364,7 +370,7 @@ def opt_cse(exprs): ...@@ -364,7 +370,7 @@ def opt_cse(exprs):
:arg exprs: A list of sympy expressions: the expressions to optimize. :arg exprs: A list of sympy expressions: the expressions to optimize.
:return: A dictionary of expression substitutions :return: A dictionary of expression substitutions
""" """
opt_subs = dict() opt_subs = {}
from sumpy.tools import OrderedSet from sumpy.tools import OrderedSet
adds = OrderedSet() adds = OrderedSet()
...@@ -444,7 +450,7 @@ def tree_cse(exprs, symbols, opt_subs=None): ...@@ -444,7 +450,7 @@ def tree_cse(exprs, symbols, opt_subs=None):
:return: A pair (replacements, reduced exprs) :return: A pair (replacements, reduced exprs)
""" """
if opt_subs is None: if opt_subs is None:
opt_subs = dict() opt_subs = {}
# {{{ find repeated sub-expressions and used symbols # {{{ find repeated sub-expressions and used symbols
...@@ -454,7 +460,7 @@ def tree_cse(exprs, symbols, opt_subs=None): ...@@ -454,7 +460,7 @@ def tree_cse(exprs, symbols, opt_subs=None):
excluded_symbols = set() excluded_symbols = set()
def find_repeated(expr): def find_repeated(expr):
if not isinstance(expr, (Basic, Unevaluated)): if not isinstance(expr, Basic | Unevaluated):
return return
if isinstance(expr, Basic) and expr.is_Atom: if isinstance(expr, Basic) and expr.is_Atom:
...@@ -475,7 +481,7 @@ def tree_cse(exprs, symbols, opt_subs=None): ...@@ -475,7 +481,7 @@ def tree_cse(exprs, symbols, opt_subs=None):
if expr in opt_subs: if expr in opt_subs:
expr = opt_subs[expr] expr = opt_subs[expr]
if isinstance(expr, CSE_NO_DESCEND_CLASSES): if isinstance(expr, CSE_NO_DESCEND_CLASSES): # noqa: SIM108
args = () args = ()
else: else:
args = expr.args args = expr.args
...@@ -496,10 +502,10 @@ def tree_cse(exprs, symbols, opt_subs=None): ...@@ -496,10 +502,10 @@ def tree_cse(exprs, symbols, opt_subs=None):
replacements = [] replacements = []
subs = dict() subs = {}
def rebuild(expr): def rebuild(expr):
if not isinstance(expr, (Basic, Unevaluated)): if not isinstance(expr, Basic | Unevaluated):
return expr return expr
if not expr.args: if not expr.args:
...@@ -526,7 +532,7 @@ def tree_cse(exprs, symbols, opt_subs=None): ...@@ -526,7 +532,7 @@ def tree_cse(exprs, symbols, opt_subs=None):
try: try:
sym = next(symbols) sym = next(symbols)
except StopIteration: except StopIteration:
raise ValueError("Symbols iterator ran out of symbols.") raise ValueError("Symbols iterator ran out of symbols.") from None
subs[orig_expr] = sym subs[orig_expr] = sym
replacements.append((sym, new_expr)) replacements.append((sym, new_expr))
...@@ -538,7 +544,7 @@ def tree_cse(exprs, symbols, opt_subs=None): ...@@ -538,7 +544,7 @@ def tree_cse(exprs, symbols, opt_subs=None):
reduced_exprs = [] reduced_exprs = []
for e in exprs: for e in exprs:
if isinstance(e, Basic): if isinstance(e, Basic): # noqa: SIM108
reduced_e = rebuild(e) reduced_e = rebuild(e)
else: else:
reduced_e = e reduced_e = e
...@@ -580,7 +586,7 @@ def cse(exprs, symbols=None, optimizations=None): ...@@ -580,7 +586,7 @@ def cse(exprs, symbols=None, optimizations=None):
# Preprocess the expressions to give us better optimization opportunities. # Preprocess the expressions to give us better optimization opportunities.
reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs] reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
if symbols is None: if symbols is None: # noqa: SIM108
symbols = numbered_symbols(cls=Symbol) symbols = numbered_symbols(cls=Symbol)
else: else:
# In case we get passed an iterable with an __iter__ method instead of # In case we get passed an iterable with an __iter__ method instead of
......
This diff is collapsed.
from __future__ import annotations
__copyright__ = "Copyright (C) 2022 Hao Gao"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from boxtree.distributed.calculation import DistributedExpansionWrangler
import pyopencl as cl
from sumpy.fmm import SumpyExpansionWrangler
class DistributedSumpyExpansionWrangler(
        DistributedExpansionWrangler, SumpyExpansionWrangler):
    """Expansion wrangler combining boxtree's distributed wrangler with the
    sumpy wrangler.

    The overrides below move device arrays to the host (``.get()``) before
    handing them to the :class:`DistributedExpansionWrangler` implementation
    and move the results back to the device afterwards.
    """

    def __init__(
            self, context, comm, tree_indep, local_traversal, global_traversal,
            dtype, fmm_level_to_order, communicate_mpoles_via_allreduce=False,
            **kwarg):
        # The two bases take different constructor arguments, so each one is
        # initialised explicitly rather than through a cooperative super().
        # NOTE(review): the positional ``True`` is forwarded as-is; its
        # meaning is defined by boxtree's constructor -- confirm there.
        DistributedExpansionWrangler.__init__(
            self, context, comm, global_traversal, True,
            communicate_mpoles_via_allreduce=communicate_mpoles_via_allreduce)
        SumpyExpansionWrangler.__init__(
            self, tree_indep, local_traversal, dtype, fmm_level_to_order, **kwarg)

    def distribute_source_weights(self, src_weight_vecs, src_idx_all_ranks):
        """Distribute the source weights via the base class on host copies and
        return the local results as a list of device arrays, each on the
        queue of the corresponding input vector.
        """
        host_weights = [vec.get() for vec in src_weight_vecs]
        local_host_weights = super().distribute_source_weights(
            host_weights, src_idx_all_ranks)

        device_weights = []
        for local_host, original in zip(
                local_host_weights, src_weight_vecs, strict=True):
            device_weights.append(
                cl.array.to_device(original.queue, local_host))

        return device_weights

    def gather_potential_results(self, potentials, tgt_idx_all_ranks):
        """Gather each potential via the base class on host copies.

        :returns: on rank 0, an object array of device arrays (one per entry
            of *potentials*, on that entry's queue); *None* on other ranks.
        """
        rank = self.comm.Get_rank()
        host_potentials = [pot.get() for pot in potentials]

        # Deliberately an explicit loop: zero-argument super() is not
        # reliable inside comprehension scopes before Python 3.12.
        gathered_host = []
        for host_pot in host_potentials:
            gathered_host.append(
                super().gather_potential_results(host_pot, tgt_idx_all_ranks))

        if rank != 0:
            return None

        from pytools.obj_array import make_obj_array
        return make_obj_array([
            cl.array.to_device(device_pot.queue, gathered)
            for gathered, device_pot in
            zip(gathered_host, potentials, strict=True)])

    def reorder_sources(self, source_array):
        """On rank 0, index *source_array* by the global tree's
        ``user_source_ids``; other ranks return *source_array* unchanged.
        """
        if self.comm.Get_rank() != 0:
            return source_array

        queued = source_array.with_queue(source_array.queue)
        return queued[self.global_traversal.tree.user_source_ids]

    def reorder_potentials(self, potentials):
        """On rank 0, index every entry of the object array *potentials* by
        the global tree's ``sorted_target_ids``; other ranks return *None*.
        """
        if self.comm.Get_rank() != 0:
            return None

        import numpy as np
        from pytools.obj_array import obj_array_vectorize

        assert (
            isinstance(potentials, np.ndarray)
            and potentials.dtype.char == "O")

        def to_user_order(ary):
            return ary[self.global_traversal.tree.sorted_target_ids]

        return obj_array_vectorize(to_user_order, potentials)

    def communicate_mpoles(self, mpole_exps, return_stats=False):
        """Run the base class communication on a host copy of *mpole_exps*,
        write the host data back into *mpole_exps* in place, and return
        whatever the base class call returned.
        """
        host_mpoles = mpole_exps.get()
        result = super().communicate_mpoles(host_mpoles, return_stats)

        # copy the host data back into the device array
        mpole_exps[:] = host_mpoles
        return result
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.