diff --git a/pytools/datatable.py b/pytools/datatable.py index 5df1e0c837b2411013986ee1902f73e161dbdbb9..8019f30d5906d3150d28915f844bfefc75938e9c 100644 --- a/pytools/datatable.py +++ b/pytools/datatable.py @@ -123,15 +123,12 @@ class DataTable: def sort(self, columns, reverse=False): col_indices = [self.column_indices[col] for col in columns] - def mycmp(row_a, row_b): - for col_index in col_indices: - this_result = cmp(row_a[col_index], row_b[col_index]) - if this_result: - return this_result + def mykey(row): + return tuple( + row[col_index] + for col_index in col_indices) - return 0 - - self.data.sort(mycmp, reverse=reverse) + self.data.sort(reverse=reverse, key=mykey) def aggregated(self, groupby, agg_column, aggregate_func): gb_indices = [self.column_indices[col] for col in groupby] diff --git a/test/test_data_table.py b/test/test_data_table.py index 0b3343e51fca990d9a81fa700565f05958133edb..2e6fa5ce6978ee88403b83f077d7649d22331e6f 100644 --- a/test/test_data_table.py +++ b/test/test_data_table.py @@ -73,8 +73,8 @@ def test_aggregate_2(): from pytools.datatable import DataTable tbl = DataTable(["step", "value"], zip(range(20), range(20))) agg = tbl.aggregated(["step"], "value", max) - assert agg.column_data("step") == range(20) - assert agg.column_data("value") == range(20) + assert agg.column_data("step") == list(range(20)) + assert agg.column_data("value") == list(range(20)) def test_join():