diff --git a/src/arithmetic_container.py b/src/arithmetic_container.py
index 0c8fa99ec1bec0b8a2db00632b40420c757b0d98..1e962f1c5e0cf6788e532cff92601b286bcb8bd6 100644
--- a/src/arithmetic_container.py
+++ b/src/arithmetic_container.py
@@ -222,6 +222,9 @@ def work_with_arithmetic_containers(f, *args, **kwargs):
             formal_kwargs[name] = SimpleKwArg(name)
 
     if lists:
+        from pytools import all_equal
+        assert all_equal(len(lst) for lst in lists)
+
         return ArithmeticList(
                 f(
                     *list(formal_arg.eval(tp) for formal_arg in formal_args), 
@@ -235,9 +238,9 @@ def work_with_arithmetic_containers(f, *args, **kwargs):
 
 
 
-def outer_product(al1, al2):
+def outer_product(al1, al2, mult_op=operator.mul):
     return ArithmeticListMatrix(
-            [[al1i*al2i for al2i in al2] for al1i in al1]
+            [[mult_op(al1i, al2i) for al2i in al2] for al1i in al1]
             )
 
 
@@ -257,9 +260,9 @@ class ArithmeticListMatrix:
         """
         self.matrix = matrix
 
-    def __mul__(self, other):
+    def times(self, other, mult_op):
         if not isinstance(other, ArithmeticList):
-            return NotImplemented
+            raise NotImplementedError
 
         result = ArithmeticList(None for i in range(len(self.matrix)))
 
@@ -270,7 +273,7 @@ class ArithmeticListMatrix:
             for j, entry in enumerate(row):
                 if not isinstance(entry, (int, float)) or entry:
                     if not isinstance(entry, (int, float)) or entry != 1:
-                        contrib = entry * other[j]
+                        contrib = mult_op(entry, other[j])
                     else:
                         contrib = other[j]
 
@@ -285,6 +288,13 @@ class ArithmeticListMatrix:
 
         return result
 
+    def __mul__(self, other):
+        if not isinstance(other, ArithmeticList):
+            return NotImplemented
+
+        from operator import mul
+        return self.times(other, mul)
+
     def map(self, entry_map):
         return ArithmeticListMatrix([[
             entry_map(entry)