diff --git a/dagrt/function_registry.py b/dagrt/function_registry.py index 5082151604d98103f64128b07582622a6865cb2c..e084e7c05386ab18cc6b7326fd8c874cb36e0256 100644 --- a/dagrt/function_registry.py +++ b/dagrt/function_registry.py @@ -32,6 +32,48 @@ from dagrt.codegen.data import ( NoneType = type(None) +__doc__ = """ +The function registry is used by targets to resolve external functions and +invoke user-specified code, including but not limited to ODE right-hand sides. + +.. autoclass:: Function +.. autoclass:: FunctionRegistry +.. autoclass:: FunctionNotFound + +.. data:: base_function_registry + +The default function registry, containing all the built-in functions (see +:ref:`built-ins`). + +Registering new functions +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: register_ode_rhs +.. autofunction:: register_function + +.. _built-ins: + +Built-ins +^^^^^^^^^ + +The built-in functions are listed below. This also serves as their language +documentation. + +.. autoclass:: Norm1 +.. autoclass:: Norm2 +.. autoclass:: NormInf +.. autoclass:: DotProduct +.. autoclass:: Len +.. autoclass:: IsNaN +.. autoclass:: Array_ +.. autoclass:: MatMul +.. autoclass:: Transpose +.. autoclass:: LinearSolve +.. autoclass:: SVD +.. autoclass:: Print + +""" + # {{{ function @@ -51,8 +93,21 @@ class Function(RecordWithoutPickling): assigned. .. attribute:: identifier + + The name of the function. + .. attribute:: arg_names + + The names of the arguments to the function. + .. attribute:: default_dict + + A dictionary mapping argument names to default values. + + .. automethod:: get_result_kinds + .. automethod:: register_codegen + .. automethod:: get_codegen + .. automethod:: resolve_args """ def __init__(self, language_to_codegen=None, **kwargs): @@ -84,6 +139,12 @@ class Function(RecordWithoutPickling): raise NotImplementedError() def register_codegen(self, language, codegen_function): + """Return a copy of *self* with *codegen_function* + registered as a code generator for *language*. + + The interface for *codegen_function* depends on the + code generator being used. + """ new_language_to_codegen = self.language_to_codegen.copy() if language in new_language_to_codegen: @@ -95,6 +156,8 @@ class Function(RecordWithoutPickling): return self.copy(language_to_codegen=new_language_to_codegen) def get_codegen(self, language): + """Return the code generator for *language*. + """ try: return self.language_to_codegen[language] except KeyError: @@ -103,6 +166,13 @@ class Function(RecordWithoutPickling): % (self.identifier, language)) def resolve_args(self, arg_dict): + """Resolve positional and keyword arguments to an argument list. + + See also :func:`dagrt.utils.resolve_args`. + + :arg arg_dict: a dictionary mapping numbers (for positional arguments) + or identifiers (for keyword arguments) to values + """ from dagrt.utils import resolve_args return resolve_args(self.arg_names, self.default_dict, arg_dict) @@ -124,6 +194,13 @@ class FixedResultKindsFunction(Function): # {{{ function registry class FunctionRegistry(RecordWithoutPickling): + """ + .. automethod:: register + .. automethod:: __getitem__ + .. automethod:: __contains__ + .. automethod:: register_codegen + .. automethod:: get_codegen + """ def __init__(self, id_to_function=None): if id_to_function is None: id_to_function = {} @@ -143,6 +220,10 @@ class FunctionRegistry(RecordWithoutPickling): return self.copy(id_to_function=new_id_to_function) def __getitem__(self, function_id): + """Return the :class:`Function` with identifier *function_id*. + + :raises FunctionNotFound: when *function_id* was not found + """ try: return self.id_to_function[function_id] except KeyError: @@ -200,24 +281,32 @@ class _NormBase(Function): return (Scalar(is_real_valued=True),) -class _Norm1(_NormBase): - """``norm_1(x)`` returns the 1-norm of *x*.""" +class Norm1(_NormBase): + """``norm_1(x)`` returns the 1-norm of *x*. + *x* is a user type or array. + """ identifier = "norm_1" -class _Norm2(_NormBase): - """``norm_2(x)`` returns the 2-norm of *x*.""" +class Norm2(_NormBase): + """``norm_2(x)`` returns the 2-norm of *x*. + *x* is a user type or array. + """ identifier = "norm_2" -class _NormInf(_NormBase): - """``norm_inf(x)`` returns the infinity-norm of *x*.""" +class NormInf(_NormBase): + """``norm_inf(x)`` returns the infinity-norm of *x*. + *x* is a user type or array. + """ identifier = "norm_inf" -class _DotProduct(Function): - """dot_product(x, y)`` return the dot product of *x* and *y*. The +class DotProduct(Function): + """``dot_product(x, y)`` return the dot product of *x* and *y*. The complex conjugate of *x* is taken first, if applicable. + *x* and *y* are either arrays (that must be of the same length) or the + same user type. """ result_names = ("result",) @@ -236,8 +325,10 @@ class _DotProduct(Function): return (Scalar(is_real_valued=False),) -class _Len(Function): - """``len(x)`` returns the number of degrees of freedom in *x* """ +class Len(Function): + """``len(x)`` returns the number of degrees of freedom in *x*. + *x* is a user type or array. + """ result_names = ("result",) identifier = "len" @@ -253,8 +344,10 @@ class _Len(Function): return (Scalar(is_real_valued=True),) -class _IsNaN(Function): - """``isnan(x)`` returns True if there are any NaNs in *x*""" +class IsNaN(Function): + """``isnan(x)`` returns True if and only if there are any NaNs in *x*. + *x* is a user type, scalar, or array. + """ result_names = ("result",) identifier = "isnan" @@ -270,9 +363,9 @@ class _IsNaN(Function): return (Boolean(),) -class _Array(Function): - """``array(n)`` returns an empty array with n entries in it. - n must be an integer. +class Array_(Function): # noqa + """``array(n)`` returns an empty array with *n* entries in it. + *n* must be an integer. """ result_names = ("result",) @@ -289,10 +382,10 @@ class _Array(Function): return (Array(is_real_valued=True),) -class _MatMul(Function): +class MatMul(Function): """``matmul(a, b, a_cols, b_cols)`` returns a 1D array containing the - matrix resulting from multiplying the arrays a and b (both interpreted - as matrices, with a number of columns *a_cols* and *b_cols* respectively) + matrix resulting from multiplying the arrays *a* and *b* (both interpreted + as matrices, with a number of columns *a_cols* and *b_cols* respectively). """ result_names = ("result",) @@ -321,9 +414,10 @@ class _MatMul(Function): return (Array(is_real_valued),) -class _Transpose(Function): +class Transpose(Function): """``transpose(a, a_cols)`` returns a 1D array containing the - matrix resulting from transposing the array a + matrix resulting from transposing the array *a* (interpreted + as a matrix with *a_cols* columns). """ result_names = ("result",) @@ -348,10 +442,11 @@ class _Transpose(Function): return (Array(is_real_valued),) -class _LinearSolve(Function): +class LinearSolve(Function): """``linear_solve(a, b, a_cols, b_cols)`` returns a 1D array containing the - matrix resulting from multiplying the matrix inverse of a by b (both interpreted - as matrices, with a number of columns *a_cols* and *b_cols* respectively) + matrix resulting from multiplying the matrix inverse of *a* by *b*, both + interpreted as matrices, with a number of columns *a_cols* and *b_cols* + respectively. """ result_names = ("result",) @@ -380,7 +475,7 @@ class _LinearSolve(Function): return (Array(is_real_valued),) -class _SVD(Function): +class SVD(Function): """``SVD(a, a_cols)`` returns a 2D array ``u``, a 1D array ``sigma``, and a 2D array ``vt``, representing the (reduced) SVD of ``a``. """ @@ -407,7 +502,7 @@ class _SVD(Function): return (Array(is_real_valued), Array(is_real_valued), Array(is_real_valued)) -class _Print(Function): +class Print(Function): """``print(arg)`` prints the given operand to standard output. Returns an integer that may be ignored. """ @@ -443,18 +538,18 @@ def _make_bfr(): bfr = FunctionRegistry() for func, py_pattern in [ - (_Norm1(), "self._builtin_norm_1({args})"), - (_Norm2(), "self._builtin_norm_2({args})"), - (_NormInf(), "self._builtin_norm_inf({args})"), - (_DotProduct(), "{numpy}.vdot({args})"), - (_Len(), "{numpy}.size({args})"), - (_IsNaN(), "{numpy}.isnan({args})"), - (_Array(), "self._builtin_array({args})"), - (_MatMul(), "self._builtin_matmul({args})"), - (_Transpose(), "self._builtin_transpose({args})"), - (_LinearSolve(), "self._builtin_linear_solve({args})"), - (_Print(), "self._builtin_print({args})"), - (_SVD(), "self._builtin_svd({args})"), + (Norm1(), "self._builtin_norm_1({args})"), + (Norm2(), "self._builtin_norm_2({args})"), + (NormInf(), "self._builtin_norm_inf({args})"), + (DotProduct(), "{numpy}.vdot({args})"), + (Len(), "{numpy}.size({args})"), + (IsNaN(), "{numpy}.isnan({args})"), + (Array_(), "self._builtin_array({args})"), + (MatMul(), "self._builtin_matmul({args})"), + (Transpose(), "self._builtin_transpose({args})"), + (LinearSolve(), "self._builtin_linear_solve({args})"), + (Print(), "self._builtin_print({args})"), + (SVD(), "self._builtin_svd({args})"), ]: bfr = bfr.register(func) @@ -466,23 +561,23 @@ def _make_bfr(): import dagrt.codegen.fortran as f - bfr = bfr.register_codegen(_Norm2.identifier, "fortran", + bfr = bfr.register_codegen(Norm2.identifier, "fortran", f.codegen_builtin_norm_2) - bfr = bfr.register_codegen(_Len.identifier, "fortran", + bfr = bfr.register_codegen(Len.identifier, "fortran", f.codegen_builtin_len) - bfr = bfr.register_codegen(_IsNaN.identifier, "fortran", + bfr = bfr.register_codegen(IsNaN.identifier, "fortran", f.codegen_builtin_isnan) - bfr = bfr.register_codegen(_Array.identifier, "fortran", + bfr = bfr.register_codegen(Array_.identifier, "fortran", f.builtin_array) - bfr = bfr.register_codegen(_MatMul.identifier, "fortran", + bfr = bfr.register_codegen(MatMul.identifier, "fortran", f.builtin_matmul) - bfr = bfr.register_codegen(_Transpose.identifier, "fortran", + bfr = bfr.register_codegen(Transpose.identifier, "fortran", f.builtin_transpose) - bfr = bfr.register_codegen(_LinearSolve.identifier, "fortran", + bfr = bfr.register_codegen(LinearSolve.identifier, "fortran", f.builtin_linear_solve) - bfr = bfr.register_codegen(_SVD.identifier, "fortran", + bfr = bfr.register_codegen(SVD.identifier, "fortran", f.builtin_svd) - bfr = bfr.register_codegen(_Print.identifier, "fortran", + bfr = bfr.register_codegen(Print.identifier, "fortran", f.builtin_print) return bfr @@ -547,6 +642,33 @@ def register_ode_rhs( function_registry, output_type_id, identifier=None, input_type_ids=None, input_names=None): + """Register a function as an ODE right-hand side. + + Functions registered through this call have the following characteristics. + First, there is a single return value of the user type whose type identifier + is *output_type_id*. Second, the function has as its first argument a scalar + named *t*. Last, the remaining argument list to the function consists of + user type values. + + For example, considering the ODE :math:`y' = f(t, y)`, the following call + registers a right-hand side function with name *f* and user type *y*:: + + freg = register_ode_rhs(freg, "y", identifier="f") + + :arg function_registry: the base function registry + :arg output_type_id: a string, the user type ID returned by the call. + :arg identifier: the full name of the function. If not provided, defaults + to * + output_type_id*. + :arg input_type_ids: a tuple of strings, the identifiers of the user types + which are the arguments to the right-hand side function. An automatically + added *t* argument occurs before these arguments. If not provided, + defaults to *(output_type_id,)*. + :arg input_names: a tuple of strings, the names of the inputs. If not provided, + defaults to *input_type_ids*. + + :returns: a new :class:`FunctionRegistry` + + """ if identifier is None: identifier = ""+output_type_id @@ -566,6 +688,19 @@ def register_function( default_dict=None, result_names=(), result_kinds=()): + """Register a function returning output(s) of fixed kind. + + :arg function_registry: the base :class:`FunctionRegistry` + :arg identifier: a string, the function identifier + :arg arg_names: a list of strings, the names of the arguments + :arg default_dict: a dictionary mapping argument names to default + values + :arg result_names: a list of strings, the names of the output(s) + :arg result_kinds: a list of :class:`dagrt.codegen.data.SymbolKinds`, + the kinds of the output(s) + + :returns: a new :class:`FunctionRegistry` + """ return function_registry.register( FixedResultKindsFunction( diff --git a/doc/reference.rst b/doc/reference.rst index 6d5e74daa7a967374689b57b3c3c50d46078c8c2..61a432c3202147afd0a64c5342efaba8c01d3de4 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -28,11 +28,7 @@ Fortran Function registry ~~~~~~~~~~~~~~~~~ -The function registry is used by targets to register external -functions and customized function call code. - .. automodule:: dagrt.function_registry - :members: Transformations ~~~~~~~~~~~~~~~