diff --git a/pytools/datatable.py b/pytools/datatable.py index 88aa522af0e56903bfdbcccd2a0ab456e2154880..5df1e0c837b2411013986ee1902f73e161dbdbb9 100644 --- a/pytools/datatable.py +++ b/pytools/datatable.py @@ -133,6 +133,37 @@ class DataTable: self.data.sort(mycmp, reverse=reverse) + def aggregated(self, groupby, agg_column, aggregate_func): + gb_indices = [self.column_indices[col] for col in groupby] + agg_index = self.column_indices[agg_column] + + first = True + + result_data = [] + + # to pacify pyflakes: + last_values = None + agg_values = None + + for row in self.data: + this_values = tuple(row[i] for i in gb_indices) + if first or this_values != last_values: + if not first: + result_data.append(last_values + (aggregate_func(agg_values),)) + + agg_values = [row[agg_index]] + last_values = this_values + first = False + else: + agg_values.append(row[agg_index]) + + if not first and agg_values: + result_data.append(this_values + (aggregate_func(agg_values),)) + + return DataTable( + [self.column_names[i] for i in gb_indices] + [agg_column], + result_data) + def join(self, column, other_column, other_table, outer=False): """Return a tabled joining this and the C{other_table} on C{column}.