diff --git a/test/test_codegen.py b/test/test_codegen.py index 90cb68b52436d24a0f82ff582fb1ff28336dffb8..bf7c31ca2c4ba7d6c16ffb50701d178fba5735e8 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -827,9 +827,16 @@ def test_reductions(ctx_factory, axis, redn, shape): (" ij -> ", # np.sum [(10, 4)]), - ("dij,ej,ej,dej->ei", + ("dij,ej,ej,dej->ei", # diff: curvimesh [(2, 10, 10), (100, 10), (100, 10), (2, 100, 10)]), + + ("dij,ej,ej,dej->ei", # diff: simplex + [(2, 10, 10), (100, 1), + (100, 1), (2, 100, 10)]), + + ("ij,ij->ij", # broadcasting + [(1, 3), (3, 1)]), ])) def test_einsum(ctx_factory, spec, argshapes): ctx = ctx_factory()