diff --git a/pytools/datatable.py b/pytools/datatable.py index 6493412ea31682c0bb3cc5ddf59a0e20f5eda57e..6c9de70f5f4c40652c7e23482c4e509e2c5df4e4 100644 --- a/pytools/datatable.py +++ b/pytools/datatable.py @@ -1,3 +1,9 @@ +from pytools import Record +class Row(Record): + pass + + + class DataTable: """An in-memory relational database table.""" @@ -94,8 +100,7 @@ class DataTable: if len(filtered) > 1: raise RuntimeError, "more than one matching entry for get()" - from pytools import Record - return Record(dict(zip(self.column_names, filtered.data[0]))) + return Row(dict(zip(self.column_names, filtered.data[0]))) def clear(self): del self.data[:] diff --git a/test/test-prefork.py b/test/test-prefork.py deleted file mode 100644 index f4ca52a1a6eb3c78ba51cb83d914d598eec16417..0000000000000000000000000000000000000000 --- a/test/test-prefork.py +++ /dev/null @@ -1,9 +0,0 @@ -import pytools.prefork as pf - -print pf.call_capture_stdout(["nvcc", "--version"]) -pf.enable_prefork() -from time import sleep -print "NOW" -sleep(17) - -print pf.call_capture_stdout(["nvcc", "--version"]) diff --git a/test/test.py b/test/test.py deleted file mode 100644 index 6746ca3923985b6dc6ec5eabba74d2249dbb4118..0000000000000000000000000000000000000000 --- a/test/test.py +++ /dev/null @@ -1,119 +0,0 @@ -from __future__ import division -import unittest - - - - -class TestMathStuff(unittest.TestCase): - def test_variance(self): - data = [4, 7, 13, 16] - - def naive_var(data): - n = len(data) - return (( - sum(di**2 for di in data) - - sum(data)**2/n) - /(n-1)) - - from pytools import variance - orig_variance = variance(data, entire_pop=False) - - assert abs(naive_var(data) - orig_variance) < 1e-15 - - data = [1e9 + x for x in data] - assert abs(variance(data, entire_pop=False) - orig_variance) < 1e-15 - - - - -class TestDataTable(unittest.TestCase): - # data from Wikipedia "join" article - - def get_dept_table(self): - from pytools.datatable import DataTable - dept_table = DataTable(["id", "name"]) - dept_table.insert_row((31, "Sales")) - dept_table.insert_row((33, "Engineering")) - dept_table.insert_row((34, "Clerical")) - dept_table.insert_row((35, "Marketing")) - return dept_table - - def get_employee_table(self): - from pytools.datatable import DataTable - employee_table = DataTable(["lastname", "dept"]) - employee_table.insert_row(("Rafferty", 31)) - employee_table.insert_row(("Jones", 33)) - employee_table.insert_row(("Jasper", 36)) - employee_table.insert_row(("Steinberg", 33)) - employee_table.insert_row(("Robinson", 34)) - employee_table.insert_row(("Smith", 34)) - return employee_table - - def test_len(self): - et = self.get_employee_table() - assert len(et) == 6 - - def test_iter(self): - et = self.get_employee_table() - - count = 0 - for row in et: - count += 1 - assert len(row) == 2 - - assert count == 6 - - def test_insert_and_get(self): - et = self.get_employee_table() - et.insert(dept=33, lastname="Kloeckner") - assert et.get(lastname="Kloeckner").dept == 33 - - def test_filtered(self): - et = self.get_employee_table() - assert len(et.filtered(dept=33)) == 2 - assert len(et.filtered(dept=34)) == 2 - - def test_sort(self): - et = self.get_employee_table() - et.sort(["lastname"]) - assert et.column_data("dept") == [36,33,31,34,34,33] - - def test_aggregate(self): - et = self.get_employee_table() - et.sort(["dept"]) - agg = et.aggregated(["dept"], "lastname", lambda lst: ",".join(lst)) - assert len(agg) == 4 - for dept, lastnames in agg: - lastnames = lastnames.split(",") - for lastname in lastnames: - assert et.get(lastname=lastname).dept == dept - - def test_aggregate_2(self): - from pytools.datatable import DataTable - tbl = DataTable(["step", "value"], zip(range(20), range(20))) - agg = tbl.aggregated(["step"], "value", max) - assert agg.column_data("step") == range(20) - assert agg.column_data("value") == range(20) - - def test_join(self): - et = self.get_employee_table() - dt = self.get_dept_table() - - et.sort(["dept"]) - dt.sort(["id"]) - - inner_joined = et.join("dept", "id", dt) - assert len(inner_joined) == len(et)-1 - for dept, lastname, deptname in inner_joined: - dept_id = et.get(lastname=lastname).dept - assert dept_id == dept - assert dt.get(id=dept_id).name == deptname - - outer_joined = et.join("dept", "id", dt, outer=True) - assert len(outer_joined) == len(et)+1 - - - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_data_table.py b/test/test_data_table.py new file mode 100644 index 0000000000000000000000000000000000000000..721fdec88a1522a76d407b36cfa0a6c59f7c6c92 --- /dev/null +++ b/test/test_data_table.py @@ -0,0 +1,86 @@ +from __future__ import division +# data from Wikipedia "join" article + +def get_dept_table(): + from pytools.datatable import DataTable + dept_table = DataTable(["id", "name"]) + dept_table.insert_row((31, "Sales")) + dept_table.insert_row((33, "Engineering")) + dept_table.insert_row((34, "Clerical")) + dept_table.insert_row((35, "Marketing")) + return dept_table + +def get_employee_table(): + from pytools.datatable import DataTable + employee_table = DataTable(["lastname", "dept"]) + employee_table.insert_row(("Rafferty", 31)) + employee_table.insert_row(("Jones", 33)) + employee_table.insert_row(("Jasper", 36)) + employee_table.insert_row(("Steinberg", 33)) + employee_table.insert_row(("Robinson", 34)) + employee_table.insert_row(("Smith", 34)) + return employee_table + +def test_len(): + et = get_employee_table() + assert len(et) == 6 + +def test_iter(): + et = get_employee_table() + + count = 0 + for row in et: + count += 1 + assert len(row) == 2 + + assert count == 6 + +def test_insert_and_get(): + et = get_employee_table() + et.insert(dept=33, lastname="Kloeckner") + assert et.get(lastname="Kloeckner").dept == 33 + +def test_filtered(): + et = get_employee_table() + assert len(et.filtered(dept=33)) == 2 + assert len(et.filtered(dept=34)) == 2 + +def test_sort(): + et = get_employee_table() + et.sort(["lastname"]) + assert et.column_data("dept") == [36,33,31,34,34,33] + +def test_aggregate(): + et = get_employee_table() + et.sort(["dept"]) + agg = et.aggregated(["dept"], "lastname", lambda lst: ",".join(lst)) + assert len(agg) == 4 + for dept, lastnames in agg: + lastnames = lastnames.split(",") + for lastname in lastnames: + assert et.get(lastname=lastname).dept == dept + +def test_aggregate_2(): + from pytools.datatable import DataTable + tbl = DataTable(["step", "value"], zip(range(20), range(20))) + agg = tbl.aggregated(["step"], "value", max) + assert agg.column_data("step") == range(20) + assert agg.column_data("value") == range(20) + +def test_join(): + et = get_employee_table() + dt = get_dept_table() + + et.sort(["dept"]) + dt.sort(["id"]) + + inner_joined = et.join("dept", "id", dt) + assert len(inner_joined) == len(et)-1 + for dept, lastname, deptname in inner_joined: + dept_id = et.get(lastname=lastname).dept + assert dept_id == dept + assert dt.get(id=dept_id).name == deptname + + outer_joined = et.join("dept", "id", dt, outer=True) + assert len(outer_joined) == len(et)+1 + diff --git a/test/test_math_stuff.py b/test/test_math_stuff.py new file mode 100644 index 0000000000000000000000000000000000000000..5768c83207330a4b6f2f09249f558347f339ddc9 --- /dev/null +++ b/test/test_math_stuff.py @@ -0,0 +1,23 @@ +from __future__ import division + + + + +def test_variance(): + data = [4, 7, 13, 16] + + def naive_var(data): + n = len(data) + return (( + sum(di**2 for di in data) + - sum(data)**2/n) + /(n-1)) + + from pytools import variance + orig_variance = variance(data, entire_pop=False) + + assert abs(naive_var(data) - orig_variance) < 1e-15 + + data = [1e9 + x for x in data] + assert abs(variance(data, entire_pop=False) - orig_variance) < 1e-15 + diff --git a/test/test_pytools.py b/test/test_pytools.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391