diff --git a/src/arithmetic_container.py b/src/arithmetic_container.py index 0c8fa99ec1bec0b8a2db00632b40420c757b0d98..1e962f1c5e0cf6788e532cff92601b286bcb8bd6 100644 --- a/src/arithmetic_container.py +++ b/src/arithmetic_container.py @@ -222,6 +222,9 @@ def work_with_arithmetic_containers(f, *args, **kwargs): formal_kwargs[name] = SimpleKwArg(name) if lists: + from pytools import all_equal + assert all_equal(len(lst) for lst in lists) + return ArithmeticList( f( *list(formal_arg.eval(tp) for formal_arg in formal_args), @@ -235,9 +238,9 @@ def work_with_arithmetic_containers(f, *args, **kwargs): -def outer_product(al1, al2): +def outer_product(al1, al2, mult_op=operator.mul): return ArithmeticListMatrix( - [[al1i*al2i for al2i in al2] for al1i in al1] + [[mult_op(al1i, al2i) for al2i in al2] for al1i in al1] ) @@ -257,9 +260,9 @@ class ArithmeticListMatrix: """ self.matrix = matrix - def __mul__(self, other): + def times(self, other, mult_op): if not isinstance(other, ArithmeticList): - return NotImplemented + raise NotImplementedError result = ArithmeticList(None for i in range(len(self.matrix))) @@ -270,7 +273,7 @@ class ArithmeticListMatrix: for j, entry in enumerate(row): if not isinstance(entry, (int, float)) or entry: if not isinstance(entry, (int, float)) or entry != 1: - contrib = entry * other[j] + contrib = mult_op(entry, other[j]) else: contrib = other[j] @@ -285,6 +288,13 @@ class ArithmeticListMatrix: return result + def __mul__(self, other): + if not isinstance(other, ArithmeticList): + return NotImplemented + + from operator import mul + return self.times(other, mul) + def map(self, entry_map): return ArithmeticListMatrix([[ entry_map(entry)