diff --git a/pytools/obj_array.py b/pytools/obj_array.py index 7dc49ae0dfefdd16cf77c5856487216f6563ea78..597d212e2e41d34683ad6b2b9008c8f12535bea1 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -1,6 +1,6 @@ -from __future__ import absolute_import -import numpy -from pytools import my_decorator as decorator +from __future__ import absolute_import, division +import numpy as np +from pytools import my_decorator as decorator, MovedFunctionDeprecationWrapper def gen_len(expr): @@ -21,14 +21,14 @@ def gen_slice(expr, slice): def is_obj_array(val): try: - return isinstance(val, numpy.ndarray) and val.dtype == object + return isinstance(val, np.ndarray) and val.dtype == object except AttributeError: return False def to_obj_array(ary): ls = log_shape(ary) - result = numpy.empty(ls, dtype=object) + result = np.empty(ls, dtype=object) from pytools import indices_in_shape for i in indices_in_shape(ls): @@ -45,7 +45,7 @@ def is_field_equal(a, b): def make_obj_array(res_list): - result = numpy.empty((len(res_list),), dtype=object) + result = np.empty((len(res_list),), dtype=object) for i, v in enumerate(res_list): result[i] = v @@ -60,29 +60,33 @@ def setify_field(f): return set([f]) -def hashable_field(f): +def obj_array_to_hashable(f): if is_obj_array(f): return tuple(f) else: return f +hashable_field = MovedFunctionDeprecationWrapper(obj_array_to_hashable) -def field_equal(a, b): + +def obj_array_equal(a, b): a_is_oa = is_obj_array(a) assert a_is_oa == is_obj_array(b) if a_is_oa: - return (a == b).all() + return np.array_equal(a, b) else: return a == b +field_equal = MovedFunctionDeprecationWrapper(obj_array_equal) + def join_fields(*args): res_list = [] for arg in args: if isinstance(arg, list): res_list.extend(arg) - elif isinstance(arg, numpy.ndarray): + elif isinstance(arg, np.ndarray): if log_shape(arg) == (): res_list.append(arg) else: @@ -118,7 +122,7 @@ def with_object_array_or_scalar(f, field, obj_array_only=False): ls = log_shape(field) if ls != (): from pytools import indices_in_shape - result = numpy.zeros(ls, dtype=object) + result = np.zeros(ls, dtype=object) for i in indices_in_shape(ls): result[i] = f(field[i]) return result @@ -142,7 +146,7 @@ def with_object_array_or_scalar_n_args(f, *args): ls = log_shape(args[leading_oa_index]) if ls != (): from pytools import indices_in_shape - result = numpy.zeros(ls, dtype=object) + result = np.zeros(ls, dtype=object) new_args = list(args) for i in indices_in_shape(ls):