diff --git a/loopy/__init__.py b/loopy/__init__.py index fce380f8a5823b8e3a20ecda78240c02dd52fea2..07d536eb3b680492a5a4827b61f44b7389a4c31e 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -96,7 +96,8 @@ from loopy.transform.parameter import fix_parameters from loopy.preprocess import (preprocess_kernel, realize_reduction, infer_unknown_types) from loopy.schedule import generate_loop_schedules, get_one_scheduled_kernel -from loopy.statistics import (get_op_poly, get_gmem_access_poly, +from loopy.statistics import (get_op_poly, sum_ops_to_dtypes, + get_gmem_access_poly, get_DRAM_access_poly, get_barrier_poly, stringify_stats_mapping, sum_mem_access_to_bytes) from loopy.codegen import generate_code, generate_body @@ -169,7 +170,7 @@ __all__ = [ "generate_loop_schedules", "get_one_scheduled_kernel", "generate_code", "generate_body", - "get_op_poly", "get_gmem_access_poly", "get_DRAM_access_poly", + "get_op_poly", "sum_ops_to_dtypes", "get_gmem_access_poly", "get_DRAM_access_poly", "get_barrier_poly", "stringify_stats_mapping", "sum_mem_access_to_bytes", "CompiledKernel", diff --git a/loopy/statistics.py b/loopy/statistics.py index 8f3981de3138c12c89e44ae9f31dee1bd91dd7c9..6eb6b0057af4e5d346f3047de6eb206492a8f67d 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -459,6 +459,18 @@ def get_op_poly(knl): return op_poly.dict +def sum_ops_to_dtypes(op_poly_dict): + result = {} + for (dtype, kind), v in op_poly_dict.items(): + new_key = dtype + if new_key in result: + result[new_key] += v + else: + result[new_key] = v + + return result + + def get_gmem_access_poly(knl): # for now just counting subscripts """Count the number of global memory accesses in a loopy kernel.