diff --git a/pytato/partition.py b/pytato/partition.py index c685fdd3599f4868f3633912049e2500750e3cdc..7da3b80379a1af2e2c8905802912ae9c05ec6458 100644 --- a/pytato/partition.py +++ b/pytato/partition.py @@ -414,7 +414,7 @@ def _check_partition_disjointness(partition: GraphPartition) -> None: # {{{ generate_code_for_partition def generate_code_for_partition(partition: GraphPartition) \ - -> Dict[PartId, BoundProgram]: + -> Mapping[PartId, BoundProgram]: """Return a mapping of partition identifiers to their :class:`pytato.target.BoundProgram`.""" from pytato import generate_loopy diff --git a/pytato/target/__init__.py b/pytato/target/__init__.py index f9f1f2d2806d581ed05fc25414bad128c572df3d..6f15998b7bb8f86a78af3be5485af2cd1c43096a 100644 --- a/pytato/target/__init__.py +++ b/pytato/target/__init__.py @@ -37,19 +37,9 @@ from typing import Any, Mapping class Target: - """An abstract code generation target. - - .. automethod:: bind_program """ - - def bind_program(self, program: Any, - bound_arguments: Mapping[str, Any]) -> BoundProgram: - """Create a :class:`BoundProgram` for this code generation target. - - :param program: the :mod:`loopy` program - :param bound_arguments: a mapping from argument names to outputs - """ - raise NotImplementedError + An abstract code generation target. + """ @dataclass(init=True, repr=False, eq=False) diff --git a/pytato/target/loopy/__init__.py b/pytato/target/loopy/__init__.py index d800216a68ee5eb11489942e978795307c586a84..b893ffdf1b7c09c1c8fb25cfaec0f0c5e35f23d6 100644 --- a/pytato/target/loopy/__init__.py +++ b/pytato/target/loopy/__init__.py @@ -37,7 +37,7 @@ __doc__ = """ import sys from dataclasses import dataclass -from typing import Any, Mapping, Optional, Union, Callable +from typing import Any, Mapping, Optional, Callable from pytato.target import Target, BoundProgram @@ -55,12 +55,25 @@ class LoopyTarget(Target): """An :mod:`loopy` target. .. automethod:: get_loopy_target + + .. automethod:: bind_program """ def get_loopy_target(self) -> "loopy.TargetBase": """Return the corresponding :mod:`loopy` target.""" raise NotImplementedError + def bind_program(self, program: loopy.TranslationUnit, + bound_arguments: Mapping[str, Any]) -> BoundProgram: + """ + Create a :class:`pytato.target.BoundProgram` for this code generation + target. + + :param program: the :mod:`loopy` program + :param bound_arguments: a mapping from argument names to outputs + """ + raise NotImplementedError + class LoopyPyOpenCLTarget(LoopyTarget): """A :mod:`pyopencl` code generation target. @@ -81,11 +94,11 @@ class LoopyPyOpenCLTarget(LoopyTarget): import loopy as lp return lp.PyOpenCLTarget(self.device) - def bind_program(self, program: Union["loopy.Program", "loopy.LoopKernel"], - bound_arguments: Mapping[str, Any]) -> BoundProgram: + def bind_program(self, program: loopy.TranslationUnit, + bound_arguments: Mapping[str, Any]) -> BoundProgram: return BoundPyOpenCLProgram(program=program, - bound_arguments=bound_arguments, - target=self) + bound_arguments=bound_arguments, + target=self) @dataclass(init=True, repr=False, eq=False) diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 8055c12f5b5d9d5c78989ca3d68f51e9e9dd9648..637d81261e24f94f9e4b4d6b54e75e794fd9bd47 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -923,6 +923,8 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]], if cl_device is not None: raise TypeError("may not pass both 'target' and 'cl_device'") + assert isinstance(target, LoopyTarget) + preproc_result = preprocess(orig_outputs, target) outputs = preproc_result.outputs compute_order = preproc_result.compute_order