From 23afa5980f109ac4306b31d92c14c49000cd16ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Kl=C3=B6ckner?= Date: Thu, 26 Aug 2021 17:00:19 -0500 Subject: [PATCH] Revert "validate arguments passed to BoundPyopenclProgram" This reverts commit 87c793c2afe861222648f68416ebfbe9937d1147. --- pytato/target/__init__.py | 11 ++--------- pytato/target/loopy/__init__.py | 21 ++------------------- pytato/target/loopy/codegen.py | 7 ++----- 3 files changed, 6 insertions(+), 33 deletions(-) diff --git a/pytato/target/__init__.py b/pytato/target/__init__.py index fa4b0a9..f9f1f2d 100644 --- a/pytato/target/__init__.py +++ b/pytato/target/__init__.py @@ -33,7 +33,7 @@ Code Generation Targets """ from dataclasses import dataclass -from typing import Any, Mapping, FrozenSet +from typing import Any, Mapping class Target: @@ -43,9 +43,7 @@ class Target: """ def bind_program(self, program: Any, - bound_arguments: Mapping[str, Any], - valid_arguments: FrozenSet[str] - ) -> BoundProgram: + bound_arguments: Mapping[str, Any]) -> BoundProgram: """Create a :class:`BoundProgram` for this code generation target. :param program: the :mod:`loopy` program @@ -71,10 +69,6 @@ class BoundProgram: A map from names to pre-bound kernel arguments. - .. attribute:: valid_arguments - - A :class:`frozenset` of argument names that could be passed in. - .. method:: __call__ It is expected that every concrete subclass of this class @@ -85,7 +79,6 @@ class BoundProgram: program: Any bound_arguments: Mapping[str, Any] target: Target - valid_arguments: FrozenSet[str] def __call__(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError diff --git a/pytato/target/loopy/__init__.py b/pytato/target/loopy/__init__.py index 6b06543..976a5f6 100644 --- a/pytato/target/loopy/__init__.py +++ b/pytato/target/loopy/__init__.py @@ -37,10 +37,9 @@ __doc__ = """ import sys from dataclasses import dataclass -from typing import Any, Mapping, Optional, Union, Callable, FrozenSet +from typing import Any, Mapping, Optional, Union, Callable from pytato.target import Target, BoundProgram -from pytools import memoize_method import loopy @@ -83,12 +82,9 @@ class LoopyPyOpenCLTarget(LoopyTarget): return lp.PyOpenCLTarget(self.device) def bind_program(self, program: Union["loopy.Program", "loopy.LoopKernel"], - bound_arguments: Mapping[str, Any], - valid_arguments: FrozenSet[str] - ) -> BoundProgram: + bound_arguments: Mapping[str, Any]) -> BoundProgram: return BoundPyOpenCLProgram(program=program, bound_arguments=bound_arguments, - valid_arguments=valid_arguments, target=self) @@ -104,7 +100,6 @@ class BoundPyOpenCLProgram(BoundProgram): def copy(self, *, program: Optional[loopy.TranslationUnit] = None, bound_arguments: Optional[Mapping[str, Any]] = None, - valid_arguments: Optional[FrozenSet[str]] = None, target: Optional[Target] = None ) -> BoundPyOpenCLProgram: if program is None: @@ -113,15 +108,11 @@ class BoundPyOpenCLProgram(BoundProgram): if bound_arguments is None: bound_arguments = self.bound_arguments - if valid_arguments is None: - valid_arguments = self.valid_arguments - if target is None: target = self.target return BoundPyOpenCLProgram(program=program, bound_arguments=bound_arguments, - valid_arguments=valid_arguments, target=target) def with_transformed_program(self, f: Callable[[loopy.TranslationUnit], @@ -132,12 +123,6 @@ class BoundPyOpenCLProgram(BoundProgram): """ return self.copy(program=f(self.program)) - @memoize_method - def _validate_args(self, passed_arg_names: FrozenSet[str]) -> None: - if not (passed_arg_names <= self.valid_arguments): - raise ValueError("Unexpected arguments passed:" - f" '{passed_arg_names - self.valid_arguments}'") - def __call__(self, queue: "pyopencl.CommandQueue", # type: ignore *args: Any, **kwargs: Any) -> Any: """Convenience function for launching a :mod:`pyopencl` computation.""" @@ -146,8 +131,6 @@ class BoundPyOpenCLProgram(BoundProgram): raise ValueError("Got arguments that were previously bound: " f"{set(kwargs.keys()) & set(self.bound_arguments.keys())}.") - self._validate_args(frozenset(kwargs.keys())) - updated_kwargs = dict(self.bound_arguments) updated_kwargs.update(kwargs) if not isinstance(self. program, loopy.LoopKernel): diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 92534bd..1c14047 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -32,7 +32,7 @@ import re import pytato.scalar_expr as scalar_expr import pymbolic.primitives as prim from pymbolic import var -import typing + from typing import (Union, Optional, Mapping, Dict, Tuple, FrozenSet, Set, Any, List) @@ -881,10 +881,7 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]], return target.bind_program( program=program, - bound_arguments=preproc_result.bound_arguments, - valid_arguments=typing.cast(FrozenSet[str], - frozenset({inp.name - for inp in ing(outputs)}))) + bound_arguments=preproc_result.bound_arguments) # }}} -- GitLab