From afb91639aebd69afecf1f0f3347af3ff0e0cd26c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 17 Oct 2015 13:54:59 -0500
Subject: [PATCH] Some obj_array tweaks

---
 pytools/obj_array.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/pytools/obj_array.py b/pytools/obj_array.py
index 7dc49ae..597d212 100644
--- a/pytools/obj_array.py
+++ b/pytools/obj_array.py
@@ -1,6 +1,6 @@
-from __future__ import absolute_import
-import numpy
-from pytools import my_decorator as decorator
+from __future__ import absolute_import, division
+import numpy as np
+from pytools import my_decorator as decorator, MovedFunctionDeprecationWrapper
 
 
 def gen_len(expr):
@@ -21,14 +21,14 @@ def gen_slice(expr, slice):
 
 def is_obj_array(val):
     try:
-        return isinstance(val, numpy.ndarray) and val.dtype == object
+        return isinstance(val, np.ndarray) and val.dtype == object
     except AttributeError:
         return False
 
 
 def to_obj_array(ary):
     ls = log_shape(ary)
-    result = numpy.empty(ls, dtype=object)
+    result = np.empty(ls, dtype=object)
 
     from pytools import indices_in_shape
     for i in indices_in_shape(ls):
@@ -45,7 +45,7 @@ def is_field_equal(a, b):
 
 
 def make_obj_array(res_list):
-    result = numpy.empty((len(res_list),), dtype=object)
+    result = np.empty((len(res_list),), dtype=object)
     for i, v in enumerate(res_list):
         result[i] = v
 
@@ -60,29 +60,33 @@ def setify_field(f):
         return set([f])
 
 
-def hashable_field(f):
+def obj_array_to_hashable(f):
     if is_obj_array(f):
         return tuple(f)
     else:
         return f
 
+hashable_field = MovedFunctionDeprecationWrapper(obj_array_to_hashable)
 
-def field_equal(a, b):
+
+def obj_array_equal(a, b):
     a_is_oa = is_obj_array(a)
     assert a_is_oa == is_obj_array(b)
 
     if a_is_oa:
-        return (a == b).all()
+        return np.array_equal(a, b)
     else:
         return a == b
 
+field_equal = MovedFunctionDeprecationWrapper(obj_array_equal)
+
 
 def join_fields(*args):
     res_list = []
     for arg in args:
         if isinstance(arg, list):
             res_list.extend(arg)
-        elif isinstance(arg, numpy.ndarray):
+        elif isinstance(arg, np.ndarray):
             if log_shape(arg) == ():
                 res_list.append(arg)
             else:
@@ -118,7 +122,7 @@ def with_object_array_or_scalar(f, field, obj_array_only=False):
         ls = log_shape(field)
     if ls != ():
         from pytools import indices_in_shape
-        result = numpy.zeros(ls, dtype=object)
+        result = np.zeros(ls, dtype=object)
         for i in indices_in_shape(ls):
             result[i] = f(field[i])
         return result
@@ -142,7 +146,7 @@ def with_object_array_or_scalar_n_args(f, *args):
     ls = log_shape(args[leading_oa_index])
     if ls != ():
         from pytools import indices_in_shape
-        result = numpy.zeros(ls, dtype=object)
+        result = np.zeros(ls, dtype=object)
 
         new_args = list(args)
         for i in indices_in_shape(ls):
-- 
GitLab