diff --git a/test/test_pytato.py b/test/test_pytato.py index b77c08913881baa9031cb162ac30dd6b160b943e..5c3fdb519778891a2c5b1e6596c8b58570f6ecd1 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -266,6 +266,18 @@ def test_tagging_array(): assert any(isinstance(tag, BestArrayTag) for tag in y.tags) +def test_dict_of_named_arrays_comparison(): + # See https://github.com/inducer/pytato/pull/137 + x = pt.make_placeholder("x", (10, 4), float) + dict1 = pt.make_dict_of_named_arrays({"out": 2 * x}) + dict2 = pt.make_dict_of_named_arrays({"out": 2 * x}) + dict3 = pt.make_dict_of_named_arrays({"not_out": 2 * x}) + dict4 = pt.make_dict_of_named_arrays({"out": 3 * x}) + assert dict1 == dict2 + assert dict1 != dict3 + assert dict1 != dict4 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])