diff --git a/pytential/symbolic/primitives.py b/pytential/symbolic/primitives.py index 3a297fd1f320e28d3f2308ed46e79c878e97bb90..94a8b125cb4f15b8196067fdd1ab2bd504bf1d16 100644 --- a/pytential/symbolic/primitives.py +++ b/pytential/symbolic/primitives.py @@ -328,6 +328,8 @@ class DiscretizationProperty(Expression): further arguments. """ + init_arg_names = ("where",) + def __init__(self, where=None): """ :arg where: |where-blurb| @@ -348,6 +350,9 @@ class QWeight(DiscretizationProperty): class NodeCoordinateComponent(DiscretizationProperty): + + init_arg_names = ("ambient_axis", "where") + def __init__(self, ambient_axis, where=None): """ :arg where: |where-blurb| @@ -378,12 +383,14 @@ class NumReferenceDerivative(DiscretizationProperty): reference coordinates. """ - def __new__(cls, ref_axes, operand, where=None): + init_arg_names = ("ref_axes", "operand", "where") + + def __new__(cls, ref_axes=None, operand=None, where=None): # If the constructor is handed a multivector object, return an # object array of the operator applied to each of the # coefficients in the multivector. - if isinstance(operand, (np.ndarray)): + if isinstance(operand, np.ndarray): def make_op(operand_i): return cls(ref_axes, operand_i, where=where) @@ -750,7 +757,10 @@ def _scaled_max_curvature(ambient_dim, dim=None, where=None): # {{{ operators class SingleScalarOperandExpression(Expression): - def __new__(cls, operand): + + init_arg_names = ("operand",) + + def __new__(cls, operand=None): # If the constructor is handed a multivector object, return an # object array of the operator applied to each of the # coefficients in the multivector. @@ -792,7 +802,10 @@ def integral(ambient_dim, dim, operand, where=None): class SingleScalarOperandExpressionWithWhere(Expression): - def __new__(cls, operand, where=None): + + init_arg_names = ("operand", "where") + + def __new__(cls, operand=None, where=None): # If the constructor is handed a multivector object, return an # object array of the operator applied to each of the # coefficients in the multivector. @@ -842,6 +855,8 @@ class Ones(Expression): discretization. """ + init_arg_names = ("where",) + def __init__(self, where=None): self.where = where @@ -870,6 +885,9 @@ def mean(ambient_dim, dim, operand, where=None): class IterativeInverse(Expression): + + init_arg_names = ("expression", "rhs", "variable_name", "extra_vars", "where") + def __init__(self, expression, rhs, variable_name, extra_vars={}, where=None): self.expression = expression @@ -982,7 +1000,10 @@ class IntG(Expression): where :math:`\sigma` is *density*. """ - def __new__(cls, kernel, density, *args, **kwargs): + init_arg_names = ("kernel", "density", "qbx_forced_limit", "source", "target", + "kernel_arguments") + + def __new__(cls, kernel=None, density=None, *args, **kwargs): # If the constructor is handed a multivector object, return an # object array of the operator applied to each of the # coefficients in the multivector. @@ -1024,8 +1045,8 @@ class IntG(Expression): :arg kernel_arguments: A dictionary mapping named :class:`sumpy.kernel.Kernel` arguments (see :meth:`sumpy.kernel.Kernel.get_args` - and :meth:`sumpy.kernel.Kernel.get_source_args` - to expressions that determine them) + and :meth:`sumpy.kernel.Kernel.get_source_args`) + to expressions that determine them :arg source: The symbolic name of the source discretization. This name is bound to a concrete :class:`pytential.source.LayerPotentialSourceBase` @@ -1114,6 +1135,11 @@ class IntG(Expression): self.source, self.target, hashable_kernel_args(self.kernel_arguments)) + def __setstate__(self, state): + # Overwrite pymbolic.Expression.__setstate__ + assert len(self.init_arg_names) == len(state), type(self) + self.__init__(*state) + mapper_method = intern("map_int_g") @@ -1130,7 +1156,7 @@ def _insert_source_derivative_into_kernel(kernel): kernel, dir_vec_name=_DIR_VEC_NAME) else: return kernel.replace_inner_kernel( - _insert_source_derivative_into_kernel(kernel.kernel)) + _insert_source_derivative_into_kernel(kernel.inner_kernel)) def _get_dir_vec(dsource, ambient_dim): diff --git a/test/test_symbolic.py b/test/test_symbolic.py index 8b2e7cb409766d522447cf2da6e1f1d81af03bec..2f5633d34025210c9e37e761db1596a0201470ce 100644 --- a/test/test_symbolic.py +++ b/test/test_symbolic.py @@ -169,8 +169,38 @@ def test_tangential_onb(ctx_factory): # }}} +# {{{ test_expr_pickling + +def test_expr_pickling(): + from sumpy.kernel import LaplaceKernel, AxisTargetDerivative + import pickle + import pytential + + ops_for_testing = [ + pytential.sym.d_dx( + 2, + pytential.sym.D( + LaplaceKernel(2), pytential.sym.var("sigma"), qbx_forced_limit=-2 + ) + ), + pytential.sym.D( + AxisTargetDerivative(0, LaplaceKernel(2)), + pytential.sym.var("sigma"), + qbx_forced_limit=-2 + ) + ] + + for op in ops_for_testing: + pickled_op = pickle.dumps(op) + after_pickle_op = pickle.loads(pickled_op) + + assert op == after_pickle_op + +# }}} + + # You can test individual routines by typing -# $ python test_tools.py 'test_routine()' +# $ python test_symbolic.py 'test_routine()' if __name__ == "__main__": import sys