diff --git a/test/test_array.py b/test/test_array.py index 9ff806e3a68dbe96d88a7f1e5a296a7647d614af..116863e8c2bf3f0ddb7cc24e48263ebedfb49326 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1603,6 +1603,22 @@ def test_ravel(ctx_factory, order): x.ravel(order=order)) +def test_arithmetic_on_non_scalars(ctx_factory): + from dataclasses import dataclass + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + @dataclass + class ArrayContainer: + _data: np.ndarray + + def __eq__(self, other): + return ArrayContainer(self._data == other) + + with pytest.raises(TypeError): + ArrayContainer(np.ones(100)) + cl.array.zeros(cq, (10,), dtype=np.float64) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])