From 60b04940e37af517d0284188ddd43c3c87c57476 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 17 Jul 2024 13:21:09 -0500 Subject: [PATCH] LinearSystemOperator: switch to dataclass, add types --- sumpy/expansion/diff_op.py | 66 +++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 25 deletions(-) diff --git a/sumpy/expansion/diff_op.py b/sumpy/expansion/diff_op.py index e6d540c5..8d8700ba 100644 --- a/sumpy/expansion/diff_op.py +++ b/sumpy/expansion/diff_op.py @@ -20,14 +20,15 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from collections import namedtuple +from dataclasses import dataclass from pyrsistent import pmap from pytools import memoize from sumpy.tools import add_mi from itertools import accumulate import sumpy.symbolic as sym import logging -from typing import List +from typing import List, Mapping, Sequence +import sympy as sp logger = logging.getLogger(__name__) @@ -41,9 +42,28 @@ Differential operator interface .. autofunction:: as_scalar_pde """ -DerivativeIdentifier = namedtuple("DerivativeIdentifier", ["mi", "vec_idx"]) + +@dataclass(frozen=True) +class DerivativeIdentifier: + """ + .. autoattribute:: mi + .. autoattribute: vec_idx + """ + + mi: tuple[int, ...] + """ + Multi-index of the derivative being taken, a tuple with a number of entries + corresponding to the dimension. + """ + + vec_idx: int + """ + In a PDE system of :math:`n` variables, an integer between :math:`0` and :math:`n-1` + indicating which variable is being differentiated. + """ +@dataclass(frozen=True, eq=True) class LinearPDESystemOperator: r""" Represents a constant-coefficient linear differential operator of a @@ -52,21 +72,17 @@ class LinearPDESystemOperator: :class:`DerivativeIdentifier` to the coefficient. This object is immutable. Optionally supports a time variable as the last variable in the multi-index of the :class:`DerivativeIdentifier`. - """ - def __init__(self, dim, *eqs): - """ - :arg dim: Number of spatial dimensions of the LinearPDESystemOperator - :arg eqs: A list of dictionaries mapping a :class:`DerivativeIdentifier` - to a coefficient. - """ - self.dim = dim - self.eqs = tuple(eqs) - def __eq__(self, other): - return self.dim == other.dim and self.eqs == other.eqs + .. autoattribute:: dim + .. autoattribute:: eqs + + .. autoattribute:: order + .. autoattribute:: total_dims + .. automethod:: to_sym + """ - def __hash__(self): - return hash((self.dim, self.eqs)) + dim: int + eqs: Sequence[Mapping[DerivativeIdentifier, sp.Expr]] @property def order(self): @@ -82,7 +98,7 @@ class LinearPDESystemOperator: for k, v in eq.items(): deriv_ident_to_coeff[k] = v * param eqs.append(pmap(deriv_ident_to_coeff)) - return LinearPDESystemOperator(self.dim, *eqs) + return LinearPDESystemOperator(self.dim, tuple(eqs)) __rmul__ = __mul__ @@ -98,7 +114,7 @@ class LinearPDESystemOperator: else: res[k] = v eqs.append(pmap(res)) - return LinearPDESystemOperator(self.dim, *eqs) + return LinearPDESystemOperator(self.dim, tuple(eqs)) __radd__ = __add__ @@ -112,7 +128,7 @@ class LinearPDESystemOperator: item = self.eqs.__getitem__(idx) if not isinstance(item, tuple): item = (item,) - return LinearPDESystemOperator(self.dim, *item) + return LinearPDESystemOperator(self.dim, tuple(item)) @property def total_dims(self): @@ -306,7 +322,7 @@ def as_scalar_pde(pde: LinearPDESystemOperator, comp_idx: int) \ def laplacian(diff_op): dim = diff_op.dim empty = [pmap()] * len(diff_op.eqs) - res = LinearPDESystemOperator(dim, *empty) + res = LinearPDESystemOperator(dim, empty) for j in range(dim): mi = [0]*diff_op.total_dims mi[j] = 2 @@ -322,7 +338,7 @@ def diff(diff_op, mi): new_mi = add_mi(deriv_ident.mi, mi) res[DerivativeIdentifier(new_mi, deriv_ident.vec_idx)] = v eqs.append(pmap(res)) - return LinearPDESystemOperator(diff_op.dim, *eqs) + return LinearPDESystemOperator(diff_op.dim, tuple(eqs)) def divergence(diff_op): @@ -343,7 +359,7 @@ def gradient(diff_op): mi = [0]*diff_op.total_dims mi[i] = 1 eqs.append(diff(diff_op, tuple(mi)).eqs[0]) - return LinearPDESystemOperator(dim, *eqs) + return LinearPDESystemOperator(dim, tuple(eqs)) def curl(pde): @@ -361,7 +377,7 @@ def curl(pde): diff(pde[(i+1) % 3], mis[(i+2) % 3]) eqs.append(new_pde.eqs[0]) - return LinearPDESystemOperator(pde.dim, *eqs) + return LinearPDESystemOperator(pde.dim, tuple(eqs)) def concat(*ops): @@ -373,7 +389,7 @@ def concat(*ops): eqs = list(ops[0].eqs) for op in ops[1:]: eqs.extend(list(op.eqs)) - return LinearPDESystemOperator(dim, *eqs) + return LinearPDESystemOperator(dim, eqs) def make_identity_diff_op(ninput, noutput=1, time_dependent=False): @@ -391,4 +407,4 @@ def make_identity_diff_op(ninput, noutput=1, time_dependent=False): else: mi = tuple([0]*ninput) eqs = [pmap({DerivativeIdentifier(mi, i): 1}) for i in range(noutput)] - return LinearPDESystemOperator(ninput, *eqs) + return LinearPDESystemOperator(ninput, eqs) -- GitLab