diff --git a/pytato/codegen.py b/pytato/codegen.py index 496b1766bea9d02a88f9f3920eab3490b69dd534..fe9075abfa51df2b4c3dfdf8404b9e47b6b487af 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -639,7 +639,13 @@ def domain_for_shape(dim_names: Tuple[str, ...], dom &= aff_from_expr(dom.space, left).le_set(affs[iname]) dom &= affs[iname].lt_set(aff_from_expr(dom.space, right)) - dom, = dom.get_basic_sets() + doms = dom.get_basic_sets() + + if len(doms) == 0: + # empty set + dom = isl.BasicSet.empty(dom.get_space()) + else: + dom, = doms return dom diff --git a/test/test_pytato.py b/test/test_pytato.py index 76e11efaee25c8c77e915644e45c0842b6b62e0f..1923807b0cd9835904a31fe0045153ab96a7a249 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -125,6 +125,17 @@ def test_make_placeholder_noname(): assert x.name in knl.get_read_variables() +def test_zero_length_arrays(): + ns = pt.Namespace() + x = pt.make_placeholder(ns, shape=(0, 4), dtype=float) + y = 2*x + + assert y.shape == (0, 4) + + knl = pt.generate_loopy(y).program + assert all(dom.is_empty() for dom in knl.domains if dom.total_dim() != 0) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])