From fd59c90c08307f7327b9dbf8e425fd7f54f6a37c Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Tue, 5 Nov 2024 09:47:25 +0200
Subject: [PATCH] ruff: fix zip strict argument

---
 arraycontext/container/arithmetic.py     | 3 ++-
 arraycontext/container/dataclass.py      | 8 ++++----
 arraycontext/container/traversal.py      | 5 +++--
 arraycontext/impl/jax/fake_numpy.py      | 2 +-
 arraycontext/impl/numpy/fake_numpy.py    | 4 ++--
 arraycontext/impl/pyopencl/fake_numpy.py | 4 ++--
 arraycontext/impl/pytato/__init__.py     | 4 ++--
 arraycontext/impl/pytato/fake_numpy.py   | 2 +-
 arraycontext/loopy.py                    | 2 +-
 test/test_arraycontext.py                | 8 +++++---
 10 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index 72e39a1..22572dc 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -539,7 +539,8 @@ def with_container_arithmetic(
                     _format_binary_op_str(op_str, expr_arg1, expr_arg2)
                     for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip(
                         cls._serialize_init_arrays_code("arg1").items(),
-                        cls._serialize_init_arrays_code("arg2").items())
+                        cls._serialize_init_arrays_code("arg2").items(),
+                        strict=True)
                     })
             bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", {
                     key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py
index 492f0c9..ec4c37f 100644
--- a/arraycontext/container/dataclass.py
+++ b/arraycontext/container/dataclass.py
@@ -31,7 +31,7 @@ THE SOFTWARE.
 """
 
 from dataclasses import Field, fields, is_dataclass
-from typing import Tuple, Union, get_args, get_origin
+from typing import Union, get_args, get_origin
 
 from arraycontext.container import is_array_container_type
 
@@ -100,7 +100,7 @@ def dataclass_array_container(cls: type) -> type:
                 _BaseGenericAlias,
                 _SpecialForm,
             )
-            if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)):
+            if isinstance(f.type, _BaseGenericAlias | _SpecialForm):
                 # NOTE: anything except a Union is not allowed
                 raise TypeError(
                         f"Typing annotation not supported on field '{f.name}': "
@@ -125,8 +125,8 @@ def dataclass_array_container(cls: type) -> type:
 
 def inject_dataclass_serialization(
         cls: type,
-        array_fields: Tuple[Field, ...],
-        non_array_fields: Tuple[Field, ...]) -> type:
+        array_fields: tuple[Field, ...],
+        non_array_fields: tuple[Field, ...]) -> type:
     """Implements :func:`~arraycontext.serialize_container` and
     :func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
 
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 80f38af..62f6354 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -165,10 +165,11 @@ def _multimap_array_container_impl(
 
         for subarys in zip(
                 iterable_template,
-                *[serialize_container(_args[i]) for i in container_indices[1:]]
+                *[serialize_container(_args[i]) for i in container_indices[1:]],
+                strict=True
                 ):
             key = None
-            for i, (subkey, subary) in zip(container_indices, subarys):
+            for i, (subkey, subary) in zip(container_indices, subarys, strict=True):
                 if key is None:
                     key = subkey
                 else:
diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py
index bc9481e..094e8cf 100644
--- a/arraycontext/impl/jax/fake_numpy.py
+++ b/arraycontext/impl/jax/fake_numpy.py
@@ -187,7 +187,7 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
                         [(true_ary if kx_i == ky_i else false_ary)
                             and rec_equal(x_i, y_i)
                             for (kx_i, x_i), (ky_i, y_i)
-                            in zip(serialized_x, serialized_y)],
+                            in zip(serialized_x, serialized_y, strict=True)],
                         true_ary)
 
         return rec_equal(a, b)
diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py
index 8517ab6..f345edc 100644
--- a/arraycontext/impl/numpy/fake_numpy.py
+++ b/arraycontext/impl/numpy/fake_numpy.py
@@ -149,8 +149,8 @@ class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace):
                     [(true_ary if kx_i == ky_i else false_ary)
                         and cast(np.ndarray, self.array_equal(x_i, y_i))
                         for (kx_i, x_i), (ky_i, y_i)
-                        in zip(serialized_x, serialized_y)],
-                    true_ary)
+                        in zip(serialized_x, serialized_y, strict=True)],
+                    initial=true_ary)
 
     def arange(self, *args, **kwargs):
         return np.arange(*args, **kwargs)
diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index ac79245..ae340ca 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -236,7 +236,7 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
                         [(true_ary if kx_i == ky_i else false_ary)
                             and rec_equal(x_i, y_i)
                             for (kx_i, x_i), (ky_i, y_i)
-                            in zip(serialized_x, serialized_y)],
+                            in zip(serialized_x, serialized_y, strict=True)],
                         true_ary)
 
         return rec_equal(a, b)
@@ -346,7 +346,7 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
 
     def where(self, criterion, then, else_):
         def where_inner(inner_crit, inner_then, inner_else):
-            if isinstance(inner_crit, (bool, np.bool_)):
+            if isinstance(inner_crit, bool | np.bool_):
                 return inner_then if inner_crit else inner_else
             return cl_array.if_positive(inner_crit != 0, inner_then, inner_else,
                     queue=self._array_context.queue)
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index e3c830e..1d36971 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -676,7 +676,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
 
         return pt.einsum(spec, *[
             preprocess_arg(name, arg)
-            for name, arg in zip(arg_names, args)
+            for name, arg in zip(arg_names, args, strict=True)
             ]).tagged(_preprocess_array_tags(tagged))
 
     def clone(self):
@@ -905,7 +905,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
 
         return pt.einsum(spec, *[
             preprocess_arg(name, arg)
-            for name, arg in zip(arg_names, args)
+            for name, arg in zip(arg_names, args, strict=True)
             ]).tagged(_preprocess_array_tags(tagged))
 
     def clone(self):
diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index c6508e3..0692eb7 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -203,7 +203,7 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
                         [(true_ary if kx_i == ky_i else false_ary)
                             and rec_equal(x_i, y_i)
                             for (kx_i, x_i), (ky_i, y_i)
-                            in zip(serialized_x, serialized_y)],
+                            in zip(serialized_x, serialized_y, strict=True)],
                         true_ary)
 
         return cast(Array, rec_equal(a, b))
diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py
index a62023b..da71784 100644
--- a/arraycontext/loopy.py
+++ b/arraycontext/loopy.py
@@ -89,7 +89,7 @@ def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
         from islpy import make_zero_and_vars
         v = make_zero_and_vars(var_names, params=size_names)
         domain = v[0].domain()
-        for vname, sname in zip(var_names, size_names):
+        for vname, sname in zip(var_names, size_names, strict=True):
             domain = domain & v[0].le_set(v[vname]) & v[vname].lt_set(v[sname])
 
         domain_bset, = domain.get_basic_sets()
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 5cffb20..47d8e94 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -796,7 +796,7 @@ def test_container_map_on_device_scalar(actx_factory):
         rec_map_reduce_array_container,
     )
 
-    for size, ary in zip(expected_sizes, arys[:-1]):
+    for size, ary in zip(expected_sizes, arys[:-1], strict=True):
         result = map_array_container(lambda x: x, ary)
         assert actx.to_numpy(actx.np.array_equal(result, ary))
         result = rec_map_array_container(lambda x: x, ary)
@@ -827,7 +827,8 @@ def test_container_map(actx_factory):
                 subarray for _, subarray in arg1_iterable]
             arg2_subarrays = [
                 subarray for _, subarray in arg2_iterable]
-            for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
+            for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays,
+                                            strict=True):
                 _check_allclose(f, subarray1, subarray2)
 
     def func(x):
@@ -880,7 +881,8 @@ def test_container_multimap(actx_factory):
                 subarray for _, subarray in arg1_iterable]
             arg2_subarrays = [
                 subarray for _, subarray in arg2_iterable]
-            for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
+            for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays,
+                                            strict=True):
                 _check_allclose(f, subarray1, subarray2)
 
     def func_all_scalar(x, y):
-- 
GitLab