diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index bc83cf31f8411b40e8454df869d392a13c11a0ab..9071cbea804cfa0bad54ec3b477029bcb1d0e87c 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -30,6 +30,10 @@ from loopy.symbolic import FunctionIdentifier class ReductionOperation(object): + """Subclasses of this type have to be hashable, picklable, and + equality-comparable. + """ + def result_dtype(self, arg_dtype, inames): raise NotImplementedError @@ -39,6 +43,9 @@ class ReductionOperation(object): def __call__(self, dtype, operand1, operand2, inames): raise NotImplementedError + def __ne__(self, other): + return not self.__eq__(other) + class ScalarReductionOperation(ReductionOperation): def __init__(self, forced_result_dtype=None): @@ -53,6 +60,13 @@ class ScalarReductionOperation(ReductionOperation): return arg_dtype + def __hash__(self): + return hash((type(self), self.forced_result_dtype)) + + def __eq__(self, other): + return (type(self) == type(other) + and self.forced_result_dtype == other.forced_result_dtype) + def __str__(self): result = type(self).__name__.replace("ReductionOperation", "").lower()