From ed88555baed085b5ca5c1d1c5e80aaae26e469f3 Mon Sep 17 00:00:00 2001 From: "[6~" Date: Sun, 30 May 2021 15:15:28 -0500 Subject: [PATCH] Enable flake8-bugbear --- pytato/array.py | 2 +- pytato/scalar_expr.py | 12 +++++++++--- setup.cfg | 2 ++ test/testlib.py | 7 +++++-- 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 37036c5..e677b2a 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -763,7 +763,7 @@ class MatrixProduct(Array): elif self.x1.ndim == 2 and self.x2.ndim == 2: return (self.x1.shape[0], self.x2.shape[1]) - assert False + raise AssertionError() @property def dtype(self) -> np.dtype[Any]: diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 04ef3dc..9b00fc5 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -180,26 +180,32 @@ def get_dependencies(expression: Any, return frozenset(dep.name for dep in mapper(expression)) -def substitute(expression: Any, variable_assigments: Mapping[str, Any]) -> Any: +def substitute(expression: Any, + variable_assigments: Optional[Mapping[str, Any]]) -> Any: """Perform variable substitution in an expression. :param expression: A scalar expression, or an expression derived from such (e.g., a tuple of scalar expressions) :param variable_assigments: A mapping from variable names to substitutions """ + if variable_assigments is None: + variable_assigments = {} + from pymbolic.mapper.substitutor import make_subst_func return SubstitutionMapper(make_subst_func(variable_assigments))(expression) -def evaluate(expression: Any, context: Mapping[str, Any] = {}) -> Any: +def evaluate(expression: Any, context: Optional[Mapping[str, Any]] = None) -> Any: """ Evaluates *expression* by substituting the variable values as provided in *context*. """ + if context is None: + context = {} return EvaluationMapper(context)(expression) -def distribute(expr: Any, parameters: Set[Any] = set(), +def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(), commutative: bool = True) -> Any: if commutative: return DistributeMapper(TermCollector(parameters))(expr) diff --git a/setup.cfg b/setup.cfg index 3c5bb21..4d2c827 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,6 +6,8 @@ inline-quotes = " docstring-quotes = """ multiline-quotes = """ +# enable-flake8-bugbear + [mypy] [mypy-pytato.transform] diff --git a/test/testlib.py b/test/testlib.py index f12c529..c26f344 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -1,4 +1,4 @@ -from typing import (Any, Dict) +from typing import Any, Dict, Optional import pyopencl as cl import numpy import pytato as pt @@ -51,7 +51,7 @@ class NumpyBasedEvaluator(Mapper): def assert_allclose_to_numpy(expr: Array, queue: cl.CommandQueue, - parameters: Dict[Placeholder, Any] = {}, + parameters: Optional[Dict[Placeholder, Any]] = None, rtol=1e-7): """ Raises an :class:`AssertionError`, if there is a discrepancy between *expr* @@ -60,6 +60,9 @@ def assert_allclose_to_numpy(expr: Array, queue: cl.CommandQueue, :arg queue: An instance of :class:`pyopencl.CommandQueue` to which the generated kernel must be enqueued. """ + if parameters is None: + parameters = {} + np_result = NumpyBasedEvaluator(numpy, parameters)(expr) prog = pt.generate_loopy(expr, cl_device=queue.device) -- GitLab