Skip to content
Snippets Groups Projects
Commit 9361b1ca authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Fix hashing and equality comparison for reduction operations

parent db22395d
No related branches found
No related tags found
No related merge requests found
......@@ -30,6 +30,10 @@ from loopy.symbolic import FunctionIdentifier
class ReductionOperation(object):
"""Subclasses of this type have to be hashable, picklable, and
equality-comparable.
"""
def result_dtype(self, arg_dtype, inames):
raise NotImplementedError
......@@ -39,6 +43,9 @@ class ReductionOperation(object):
def __call__(self, dtype, operand1, operand2, inames):
raise NotImplementedError
def __ne__(self, other):
return not self.__eq__(other)
class ScalarReductionOperation(ReductionOperation):
def __init__(self, forced_result_dtype=None):
......@@ -53,6 +60,13 @@ class ScalarReductionOperation(ReductionOperation):
return arg_dtype
def __hash__(self):
return hash((type(self), self.forced_result_dtype))
def __eq__(self, other):
return (type(self) == type(other)
and self.forced_result_dtype == other.forced_result_dtype)
def __str__(self):
result = type(self).__name__.replace("ReductionOperation", "").lower()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment