diff --git a/pytools/datatable.py b/pytools/datatable.py
index 4dda68f2954daad9d2be7a8c305ac4cee978918f..41c56909cf972ee20291080b5b719aa8586850f6 100644
--- a/pytools/datatable.py
+++ b/pytools/datatable.py
@@ -1,3 +1,5 @@
+from typing import IO, Any, Callable, Iterator, List, Optional, Sequence, Tuple
+
 from pytools import Record
 
 
@@ -9,7 +11,8 @@ An in-memory relational database table
 """
 
 
-class Row(Record):
+# type-ignore-reason: Record is untyped
+class Row(Record):  # type: ignore[misc]
     pass
 
 
@@ -22,7 +25,8 @@ class DataTable:
     .. automethod:: join
     """
 
-    def __init__(self, column_names, column_data=None):
+    def __init__(self, column_names: Sequence[str],
+            column_data: Optional[List[Any]] = None) -> None:
         """Construct a new table, with the given C{column_names}.
 
         :arg column_names: An indexable of column name strings.
@@ -41,26 +45,26 @@ class DataTable:
         if len(self.column_indices) != len(self.column_names):
             raise RuntimeError("non-unique column names encountered")
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.data)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.data)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[List[Any]]:
         return self.data.__iter__()
 
-    def __str__(self):
+    def __str__(self) -> str:
         """Return a pretty-printed version of the table."""
 
-        def col_width(i):
+        def col_width(i: int) -> int:
             width = len(self.column_names[i])
             if self:
                 width = max(width, max(len(str(row[i])) for row in self.data))
             return width
         col_widths = [col_width(i) for i in range(len(self.column_names))]
 
-        def format_row(row):
+        def format_row(row: Sequence[str]) -> str:
             return "|".join([str(cell).ljust(col_width)
                 for cell, col_width in zip(row, col_widths)])
 
@@ -69,7 +73,7 @@ class DataTable:
                         [format_row(row) for row in self.data]
         return "\n".join(lines)
 
-    def insert(self, **kwargs):
+    def insert(self, **kwargs: Any) -> None:
         values = [None for i in range(len(self.column_names))]
 
         for key, val in kwargs.items():
@@ -77,16 +81,16 @@ class DataTable:
 
         self.insert_row(tuple(values))
 
-    def insert_row(self, values):
+    def insert_row(self, values: Tuple[Any, ...]) -> None:
         assert isinstance(values, tuple)
         assert len(values) == len(self.column_names)
         self.data.append(values)
 
-    def insert_rows(self, rows):
+    def insert_rows(self, rows: Sequence[Tuple[Any, ...]]) -> None:
         for row in rows:
             self.insert_row(row)
 
-    def filtered(self, **kwargs):
+    def filtered(self, **kwargs: Any) -> "DataTable":
         if not kwargs:
             return self
 
@@ -108,7 +112,7 @@ class DataTable:
 
         return DataTable(self.column_names, result_data)
 
-    def get(self, **kwargs):
+    def get(self, **kwargs: Any) -> Row:
         filtered = self.filtered(**kwargs)
         if not filtered:
             raise RuntimeError("no matching entry for get()")
@@ -117,34 +121,35 @@ class DataTable:
 
         return Row(dict(list(zip(self.column_names, filtered.data[0]))))
 
-    def clear(self):
+    def clear(self) -> None:
         del self.data[:]
 
-    def copy(self):
+    def copy(self) -> "DataTable":
         """Make a copy of the instance, but leave individual rows untouched.
 
         If the rows are modified later, they will also be modified in the copy.
         """
         return DataTable(self.column_names, self.data[:])
 
-    def deep_copy(self):
+    def deep_copy(self) -> "DataTable":
         """Make a copy of the instance down to the row level.
 
         The copy's rows may be modified independently from the original.
         """
         return DataTable(self.column_names, [row[:] for row in self.data])
 
-    def sort(self, columns, reverse=False):
+    def sort(self, columns: Sequence[str], reverse: bool = False) -> None:
         col_indices = [self.column_indices[col] for col in columns]
 
-        def mykey(row):
+        def mykey(row: Sequence[Any]) -> Tuple[Any, ...]:
             return tuple(
                     row[col_index]
                     for col_index in col_indices)
 
         self.data.sort(reverse=reverse, key=mykey)
 
-    def aggregated(self, groupby, agg_column, aggregate_func):
+    def aggregated(self, groupby: Sequence[str], agg_column: str,
+            aggregate_func: Callable[[Sequence[Any]], Any]) -> "DataTable":
         gb_indices = [self.column_indices[col] for col in groupby]
         agg_index = self.column_indices[agg_column]
 
@@ -153,8 +158,8 @@ class DataTable:
         result_data = []
 
         # to pacify pyflakes:
-        last_values = None
-        agg_values = None
+        last_values: Tuple[Any, ...] = ()
+        agg_values: List[Any] = []
 
         for row in self.data:
             this_values = tuple(row[i] for i in gb_indices)
@@ -175,8 +180,9 @@ class DataTable:
                 [self.column_names[i] for i in gb_indices] + [agg_column],
                 result_data)
 
-    def join(self, column, other_column, other_table, outer=False):
-        """Return a tabled joining this and the C{other_table} on C{column}.
+    def join(self, column: str, other_column: str, other_table: "DataTable",
+            outer: bool = False) -> "DataTable":
+        """Return a table joining this and the C{other_table} on C{column}.
 
         The new table has the following columns:
         - C{column}, titled the same as in this table.
@@ -187,7 +193,7 @@ class DataTable:
         by which they are joined.
         """  # pylint:disable=too-many-locals,too-many-branches
 
-        def without(indexable, idx):
+        def without(indexable: Tuple[str, ...], idx: int) -> Tuple[str, ...]:
             return indexable[:idx] + indexable[idx+1:]
 
         this_key_idx = self.column_indices[column]
@@ -196,9 +202,10 @@ class DataTable:
         this_iter = self.data.__iter__()
         other_iter = other_table.data.__iter__()
 
-        result_columns = [self.column_names[this_key_idx]] + \
-                without(self.column_names, this_key_idx) + \
-                without(other_table.column_names, other_key_idx)
+        # One-element tuple: tuple(name) would split the name into characters.
+        result_columns = (self.column_names[this_key_idx],) + \
+                without(tuple(self.column_names), this_key_idx) + \
+                without(tuple(other_table.column_names), other_key_idx)
 
         result_data = []
 
@@ -266,17 +273,17 @@ class DataTable:
 
         return DataTable(result_columns, result_data)
 
-    def restricted(self, columns):
+    def restricted(self, columns: Sequence[str]) -> "DataTable":
         col_indices = [self.column_indices[col] for col in columns]
 
         return DataTable(columns,
                 [[row[i] for i in col_indices] for row in self.data])
 
-    def column_data(self, column):
+    def column_data(self, column: str) -> List[Any]:
         col_index = self.column_indices[column]
         return [row[col_index] for row in self.data]
 
-    def write_csv(self, filelike, **kwargs):
+    def write_csv(self, filelike: IO[Any], **kwargs: Any) -> None:
         from csv import writer
         csvwriter = writer(filelike, **kwargs)
         csvwriter.writerow(self.column_names)
diff --git a/run-mypy.sh b/run-mypy.sh
index 2d08e14414a52b33baa7c16a4576cd6e68b7e001..39055a8cd2fd878fb36c1d05f51e90b26083cc12 100755
--- a/run-mypy.sh
+++ b/run-mypy.sh
@@ -1,3 +1,7 @@
 #! /bin/bash
 
+set -ex
+
 mypy --show-error-codes pytools
+
+mypy --strict --follow-imports=skip pytools/datatable.py