From 0c23b6821539b0374cc96141ff1a09e44b3ee888 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 16 Mar 2016 17:59:33 -0500
Subject: [PATCH] Fix types in stats gathering

---
 loopy/statistics.py | 37 +++++++++++++++++++++++++++----------
 1 file changed, 27 insertions(+), 10 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index c273edd54..c5eb3142d 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -568,17 +568,19 @@ def count(kernel, set):
 
 # {{{ get_op_poly
 
-def get_op_poly(knl):
+def get_op_poly(knl, numpy_types=True):
 
     """Count the number of operations in a loopy kernel.
 
     :parameter knl: A :class:`loopy.LoopKernel` whose operations are to be counted.
 
-    :return: A mapping of **{(** :class:`numpy.dtype` **,** :class:`string` **)**
+    :return: A mapping of **{(** *type* **,** :class:`string` **)**
              **:** :class:`islpy.PwQPolynomial` **}**.
 
-             - The :class:`numpy.dtype` specifies the type of the data being
-               operated on.
+             - The *type* specifies the type of the data being
+               accessed. This can be a :class:`numpy.dtype` if
+               *numpy_types* is True, otherwise the internal
+               loopy type.
 
              - The string specifies the operation type as
                *add*, *sub*, *mul*, *div*, *pow*, *shift*, *bw* (bitwise), etc.
@@ -614,8 +616,14 @@ def get_op_poly(knl):
         domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
         ops = op_counter(insn.assignee) + op_counter(insn.expression)
         op_poly = op_poly + ops*count(knl, domain)
-    return op_poly.dict
+    result = op_poly.dict
 
+    if numpy_types:
+        result = dict(
+                ((dtype.numpy_dtype, kind), count)
+                for (dtype, kind), count in six.iteritems(result))
+
+    return result
 # }}}
 
 
@@ -632,18 +640,20 @@ def sum_ops_to_dtypes(op_poly_dict):
 
 
 # {{{ get_gmem_access_poly
-def get_gmem_access_poly(knl):  # for now just counting subscripts
+def get_gmem_access_poly(knl, numpy_types=True):  # for now just counting subscripts
 
     """Count the number of global memory accesses in a loopy kernel.
 
     :parameter knl: A :class:`loopy.LoopKernel` whose DRAM accesses are to be
                     counted.
 
-    :return: A mapping of **{(** :class:`numpy.dtype` **,** :class:`string` **,**
+    :return: A mapping of **{(** *type* **,** :class:`string` **,**
              :class:`string` **)** **:** :class:`islpy.PwQPolynomial` **}**.
 
-             - The :class:`numpy.dtype` specifies the type of the data being
-               accessed.
+             - The *type* specifies the type of the data being
+               accessed. This can be a :class:`numpy.dtype` if
+               *numpy_types* is True, otherwise the internal
+               loopy type.
 
              - The first string in the map key specifies the global memory
                access type as
@@ -728,7 +738,14 @@ def get_gmem_access_poly(knl):  # for now just counting subscripts
             else:
                 subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames)
 
-    return subs_poly.dict
+    result = subs_poly.dict
+
+    if numpy_types:
+        result = dict(
+                ((dtype.numpy_dtype, kind, direction), count)
+                for (dtype, kind, direction), count in six.iteritems(result))
+
+    return result
 
 
 def get_DRAM_access_poly(knl):
-- 
GitLab