From 0903af284487b6ade13955e77de70783c081fdc9 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 27 Jan 2025 13:08:54 -0600
Subject: [PATCH] Ruff: enable SIM rules

---
 arraycontext/__init__.py             |  2 +-
 arraycontext/container/arithmetic.py | 25 ++++++++++++-------------
 arraycontext/container/traversal.py  |  7 ++-----
 arraycontext/impl/pytato/__init__.py | 12 +++++-------
 doc/make_numpy_coverage_table.py     | 27 +++++++++++++--------------
 pyproject.toml                       |  1 +
 test/test_arraycontext.py            |  5 +----
 7 files changed, 35 insertions(+), 44 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 674a229..74adae9 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -186,7 +186,7 @@ _depr_name_to_replacement_and_obj = {
 
 
 def __getattr__(name):
-    replacement_and_obj = _depr_name_to_replacement_and_obj.get(name, None)
+    replacement_and_obj = _depr_name_to_replacement_and_obj.get(name)
     if replacement_and_obj is not None:
         replacement, obj, year = replacement_and_obj
         from warnings import warn
diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index 22572dc..73dee6d 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -373,19 +373,18 @@ def with_container_arithmetic(
         cls_has_array_context_attr: bool | None = _cls_has_array_context_attr
         bcast_actx_array_type: bool | None = _bcast_actx_array_type
 
-        if cls_has_array_context_attr is None:
-            if hasattr(cls, "array_context"):
-                raise TypeError(
-                        f"{cls} has an 'array_context' attribute, but it does not "
-                        "set '_cls_has_array_context_attr' to *True* when calling "
-                        "with_container_arithmetic. This is being interpreted "
-                        "as '.array_context' being permitted to fail "
-                        "with an exception, which is no longer allowed. "
-                        f"If {cls.__name__}.array_context will not fail, pass "
-                        "'_cls_has_array_context_attr=True'. "
-                        "If you do not want container arithmetic to make "
-                        "use of the array context, set "
-                        "'_cls_has_array_context_attr=False'.")
+        if cls_has_array_context_attr is None and hasattr(cls, "array_context"):
+            raise TypeError(
+                    f"{cls} has an 'array_context' attribute, but it does not "
+                    "set '_cls_has_array_context_attr' to *True* when calling "
+                    "with_container_arithmetic. This is being interpreted "
+                    "as '.array_context' being permitted to fail "
+                    "with an exception, which is no longer allowed. "
+                    f"If {cls.__name__}.array_context will not fail, pass "
+                    "'_cls_has_array_context_attr=True'. "
+                    "If you do not want container arithmetic to make "
+                    "use of the array context, set "
+                    "'_cls_has_array_context_attr=False'.")
 
         if bcast_actx_array_type is None:
             if cls_has_array_context_attr:
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 301bab3..ef7141f 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -252,10 +252,7 @@ def stringify_array_container_tree(ary: ArrayOrContainer) -> str:
         else:
             for key, subary in iterable:
                 key = f"{key} ({type(subary).__name__})"
-                if level == 0:
-                    indent = ""
-                else:
-                    indent = f" |  {' ' * 4 * (level - 1)}"
+                indent = "" if level == 0 else f" |  {' ' * 4 * (level - 1)}"
 
                 lines.append(f"{indent} +-- {key}")
                 rec(lines, subary, level + 1)
@@ -833,7 +830,7 @@ def unflatten(
 
             # {{{ check strides
 
-            if strict and hasattr(template_subary_c, "strides"):
+            if strict and hasattr(template_subary_c, "strides"):  # noqa: SIM102
                 # Checking strides for 0 sized arrays is ill-defined
                 # since they cannot be indexed
                 if (
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index e3ce52a..f7f7be8 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -526,11 +526,9 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
 
                 from pytools import common_prefix
                 name_hint = common_prefix([nh.prefix for nh in name_hint_tags])
-                if name_hint:
-                    # All name_hint_tags shared at least some common prefix.
-                    function_name = f"frozen_{name_hint}"
-                else:
-                    function_name = "frozen_result"
+
+                # All name_hint_tags shared at least some common prefix.
+                function_name = f"frozen_{name_hint}" if name_hint else "frozen_result"
 
                 self._dag_transform_cache[normalized_expr] = (
                         transformed_dag, function_name)
@@ -668,7 +666,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
                     f"array type: got '{type(arg).__name__}', but expected one "
                     f"of {self.array_types}")
 
-            if name is not None:
+            if name is not None:  # noqa: SIM102
                 # Tagging Placeholders with naming-related tags is pointless:
                 # They already have names. It's also counterproductive, as
                 # multiple placeholders with the same name that are not
@@ -897,7 +895,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
                     f"array type: got '{type(arg).__name__}', but expected one "
                     f"of {self.array_types}")
 
-            if name is not None:
+            if name is not None:  # noqa: SIM102
                 # Tagging Placeholders with naming-related tags is pointless:
                 # They already have names. It's also counterproductive, as
                 # multiple placeholders with the same name that are not
diff --git a/doc/make_numpy_coverage_table.py b/doc/make_numpy_coverage_table.py
index 1a5782e..57f833d 100644
--- a/doc/make_numpy_coverage_table.py
+++ b/doc/make_numpy_coverage_table.py
@@ -232,20 +232,19 @@ if __name__ == "__main__":
     parser.add_argument("filename", nargs="?", type=pathlib.Path, default=None)
     args = parser.parse_args()
 
-    import sys
-    if args.filename is not None:
-        outf = open(args.filename, "w")
-    else:
-        outf = sys.stdout
+    def write(outf):
+        outf.write(HEADER)
+        write_array_creation_routines(outf, ctxs)
+        write_array_manipulation_routines(outf, ctxs)
+        write_linear_algebra(outf, ctxs)
+        write_logic_functions(outf, ctxs)
+        write_mathematical_functions(outf, ctxs)
 
     ctxs = initialize_contexts()
 
-    outf.write(HEADER)
-    write_array_creation_routines(outf, ctxs)
-    write_array_manipulation_routines(outf, ctxs)
-    write_linear_algebra(outf, ctxs)
-    write_logic_functions(outf, ctxs)
-    write_mathematical_functions(outf, ctxs)
-
-    if args.filename is not None:
-        outf.close()
+    if args.filename:
+        with open(args.filename, "w") as outf:
+            write(outf)
+    else:
+        import sys
+        write(sys.stdout)
diff --git a/pyproject.toml b/pyproject.toml
index 2e51586..9d6391a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,6 +74,7 @@ extend-select = [
     "RUF", # ruff
     "UP",  # pyupgrade
     "W",   # pycodestyle
+    "SIM",
 ]
 extend-ignore = [
     "C90",   # McCabe complexity
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index ad2cbb1..11ccbb1 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -153,10 +153,7 @@ def randn(shape, dtype):
     rng = np.random.default_rng()
     dtype = np.dtype(dtype)
 
-    if shape == 0:
-        ashape = 1
-    else:
-        ashape = shape
+    ashape = 1 if shape == 0 else shape
 
     if dtype.kind == "c":
         dtype = np.dtype(f"<f{dtype.itemsize // 2}")
-- 
GitLab