From 4da8668971c10f28d63d805e306306db6f1985b4 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 7 Oct 2020 23:31:18 -0500 Subject: [PATCH] account for 0-long arrays --- pytato/codegen.py | 8 +++++++- test/test_pytato.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pytato/codegen.py b/pytato/codegen.py index 496b176..fe9075a 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -639,7 +639,13 @@ def domain_for_shape(dim_names: Tuple[str, ...], dom &= aff_from_expr(dom.space, left).le_set(affs[iname]) dom &= affs[iname].lt_set(aff_from_expr(dom.space, right)) - dom, = dom.get_basic_sets() + doms = dom.get_basic_sets() + + if len(doms) == 0: + # empty set + dom = isl.BasicSet.empty(dom.get_space()) + else: + dom, = doms return dom diff --git a/test/test_pytato.py b/test/test_pytato.py index 76e11ef..1923807 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -125,6 +125,17 @@ def test_make_placeholder_noname(): assert x.name in knl.get_read_variables() +def test_zero_length_arrays(): + ns = pt.Namespace() + x = pt.make_placeholder(ns, shape=(0, 4), dtype=float) + y = 2*x + + assert y.shape == (0, 4) + + knl = pt.generate_loopy(y).program + assert all(dom.is_empty() for dom in knl.domains if dom.total_dim() != 0) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab