diff --git a/doc/reference.rst b/doc/reference.rst index bf3ddc86480c9eeab238bfc3907096569f661c2d..8b1a4d6f9f4069f6b24eeed0590ad4743dbd9ff7 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -203,6 +203,15 @@ Manipulating Instructions .. autofunction:: add_dependency +Argument types +-------------- + +.. autofunction:: add_argument_dtypes + +.. autofunction:: infer_argument_dtypes + +.. autofunction:: add_and_infer_argument_dtypes + Finishing up ------------ diff --git a/loopy/__init__.py b/loopy/__init__.py index 3545106917b230070e1ef91f9ae6db54059f9ee7..668e4ddaa92b18fe654753ba4c396fcd86f6697f 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -60,7 +60,9 @@ from loopy.kernel.data import ( Instruction) from loopy.kernel import LoopKernel -from loopy.kernel.tools import get_dot_dependency_graph +from loopy.kernel.tools import ( + get_dot_dependency_graph, add_argument_dtypes, + infer_argument_dtypes, add_and_infer_argument_dtypes) from loopy.kernel.creation import make_kernel from loopy.reduction import register_reduction_parser from loopy.subst import extract_subst, expand_subst @@ -83,7 +85,10 @@ __all__ = [ "default_preamble_generator", "make_kernel", "register_reduction_parser", - "get_dot_dependency_graph", + + "get_dot_dependency_graph", "add_argument_dtypes", + "infer_argument_dtypes", "add_and_infer_argument_dtypes", + "preprocess_kernel", "realize_reduction", "generate_loop_schedules", "generate_code", diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 02ce072459e92bacaadefeeb26a6ae0a352f8ead..2d0bbb40190cfddb5597d360507210b82a5bb863 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -28,18 +28,19 @@ THE SOFTWARE. import numpy as np -from pytools import Record, memoize_method -import islpy as isl from islpy import dim_type -import re - # {{{ add and infer argument dtypes def add_argument_dtypes(knl, dtype_dict): + """Specify remaining unspecified argument types. + + :arg dtype_dict: a mapping from argument names to :class:`numpy.dtype` + instances + """ dtype_dict = dtype_dict.copy() new_args = [] @@ -128,6 +129,10 @@ def infer_argument_dtypes(knl): else: return knl +def add_and_infer_argument_dtypes(knl, dtype_dict): + knl = add_argument_dtypes(knl, dtype_dict) + return infer_argument_dtypes(knl) + # }}} # {{{ find_all_insn_inames fixed point iteration