diff --git a/pytools/obj_array.py b/pytools/obj_array.py index f5823c474b6189b2ca3c31baeb4f2edda7c5c8b5..38039459e3d703cafd612d1c9b4b0bcd98a92911 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -172,7 +172,7 @@ def rec_obj_array_vectorize(f, ary): def rec_obj_array_vectorized(f): - wrapper = partial(rec_obj_array_vectorized, f) + wrapper = partial(rec_obj_array_vectorize, f) update_wrapper(wrapper, f) return wrapper @@ -214,9 +214,24 @@ def obj_array_vectorize_n_args(f, *args): def obj_array_vectorized_n_args(f): - wrapper = partial(obj_array_vectorize_n_args, f) + # Unfortunately, this can't use partial(), as the callable returned by it + # will not be turned into a bound method upon attribute access. + # This may happen here, because the decorator *could* be used + # on methods, since it can "look past" the leading `self` argument. + # Only exactly function objects receive this treatment. + # + # Spec link: + # https://docs.python.org/3/reference/datamodel.html#the-standard-type-hierarchy + # (under "Instance Methods", quote as of Py3.9.4) + # > Also notice that this transformation only happens for user-defined functions; + # > other callable objects (and all non-callable objects) are retrieved + # > without transformation. + + def wrapper(*args): + return obj_array_vectorize_n_args(f, *args) + update_wrapper(wrapper, f) - return f + return wrapper # {{{ workarounds for https://github.com/numpy/numpy/issues/1740 diff --git a/test/test_pytools.py b/test/test_pytools.py index fb4af28fa926ee52fd8b70e2efe11cfac61a1942..60a7f712b37b1f745511802d7a1d16df1d53bd0d 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -362,6 +362,86 @@ def test_make_obj_array_iteration(): # }}} +# {{{ test obj array vectorization and decorators + +def test_obj_array_vectorize(c=1): + np = pytest.importorskip("numpy") + la = pytest.importorskip("numpy.linalg") + + # {{{ functions + + import pytools.obj_array as obj + + def add_one(ary): + assert ary.dtype.char != "O" + return ary + c + + def two_add_one(x, y): + assert x.dtype.char != "O" and y.dtype.char != "O" + return x * y + c + + @obj.obj_array_vectorized + def vectorized_add_one(ary): + assert ary.dtype.char != "O" + return ary + c + + @obj.obj_array_vectorized_n_args + def vectorized_two_add_one(x, y): + assert x.dtype.char != "O" and y.dtype.char != "O" + return x * y + c + + class Adder: + def __init__(self, c): + self.c = c + + def add(self, ary): + assert ary.dtype.char != "O" + return ary + self.c + + @obj.obj_array_vectorized_n_args + def vectorized_add(self, ary): + assert ary.dtype.char != "O" + return ary + self.c + + adder = Adder(c) + + # }}} + + # {{{ check + + scalar_ary = np.ones(42, dtype=np.float) + object_ary = obj.make_obj_array([scalar_ary, scalar_ary, scalar_ary]) + + for func, vectorizer, nargs in [ + (add_one, obj.obj_array_vectorize, 1), + (two_add_one, obj.obj_array_vectorize_n_args, 2), + (adder.add, obj.obj_array_vectorize, 1), + ]: + input_ary = [scalar_ary] * nargs + result = vectorizer(func, *input_ary) + error = la.norm(result - c - 1) + print(error) + + input_ary = [object_ary] * nargs + result = vectorizer(func, *input_ary) + error = 0 + + for func, nargs in [ + (vectorized_add_one, 1), + (vectorized_two_add_one, 2), + (adder.vectorized_add, 1), + ]: + input_ary = [scalar_ary] * nargs + result = func(*input_ary) + + input_ary = [object_ary] * nargs + result = func(*input_ary) + + # }}} + +# }}} + + def test_tag(): from pytools.tag import Taggable, Tag, UniqueTag, NonUniqueTagError