diff --git a/src/arithmetic_container.py b/src/arithmetic_container.py index 599edeeb1f6e10b2730b9348a739901519d21464..19b08723dfdd208b7c6244de359834652dca36d7 100644 --- a/src/arithmetic_container.py +++ b/src/arithmetic_container.py @@ -207,6 +207,57 @@ def work_with_arithmetic_containers(f, *args, **kwargs): +class ArithmeticListMatrix: + """ A matrix type that operates on ArithmeticLists.""" + def __init__(self, matrix): + """Initialize the ArithmeticListMatrix. + + C{matrix} must allow the following interface: + + - len(matrix) gives the height of the matrix. + - matrix is iterable, giving the rows of the matrix. + + Each row, in turn, must support C{len()} and iteration. + """ + self.matrix = matrix + + def __mul__(self, other): + if not isinstance(other, ArithmeticList): + return NotImplemented + + result = ArithmeticList(None for i in range(len(self.matrix))) + + for i, row in enumerate(self.matrix): + if len(row) != len(other): + raise ValueError, "matrix width does not match ArithmeticList" + + for j, entry in enumerate(row): + if not isinstance(entry, (int, float)) or entry: + if not isinstance(entry, (int, float)) or entry != 1: + contrib = entry * other[j] + else: + contrib = other[j] + + if result[i] is None: + result[i] = contrib + else: + result[i] += contrib + + for i in range(len(result)): + if result[i] is None and len(other): + result[i] = 0 * other[0] + + return result + + def map(self, entry_map): + return ArithmeticListMatrix([[ + entry_map(entry) + for j, entry in enumerate(row)] + for i, row in enumerate(self.matrix)]) + + + + class ArithmeticDictionary(dict): """A dictionary with elementwise (on the values, not the keys) arithmetic operations.""" @@ -346,7 +397,3 @@ class ArithmeticDictionary(dict): for key in self: self[key] ^= other[key] return self - - - -