diff --git a/pytential/symbolic/primitives.py b/pytential/symbolic/primitives.py index fb0eb380e6ebf1b090dd9e64c3c11ef20af78b11..0a2d28ea1380823ed88e54f2d79a7ebfcf10c3ed 100644 --- a/pytential/symbolic/primitives.py +++ b/pytential/symbolic/primitives.py @@ -324,6 +324,19 @@ def mean_curvature(where): class NodeSum(Expression): """Implements a global sum over all discretization nodes.""" + def __new__(cls, operand): + # If the constructor is handed a multivector object, return an + # object array of the operator applied to each of the + # coefficients in the multivector. + + if isinstance(operand, (np.ndarray, MultiVector)): + def make_op(operand_i): + return cls(operand_i) + + return componentwise(make_op, operand) + else: + return Expression.__new__(cls) + def __init__(self, operand): self.operand = operand @@ -432,7 +445,7 @@ class IntG(Expression): # object array of the operator applied to each of the # coefficients in the multivector. - if isinstance(density, MultiVector): + if isinstance(density, (np.ndarray, MultiVector)): def make_op(operand_i): return cls(kernel, operand_i, *args, **kwargs)