diff --git a/loopy/statistics.py b/loopy/statistics.py index dbbdb97da947dc30953008350d5d0e94c98980d7..6c7f20d36ddb819c509fcef817927da234075c30 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -1236,14 +1236,19 @@ def get_op_map(knl, numpy_types=True, count_redundant_work=False): % type(insn).__name__) if numpy_types: - op_map.count_map = dict((Op( - dtype=op.dtype.numpy_dtype, - name=op.name, - count_granularity=op.count_granularity), - count) - for op, count in six.iteritems(op_map.count_map)) - - return op_map + return ToCountMap( + init_dict=dict( + (Op( + dtype=op.dtype.numpy_dtype, + name=op.name, + count_granularity=op.count_granularity + ) + , ct) + for op, ct in six.iteritems(op_map.count_map)), + val_type=op_map.val_type + ) + else: + return op_map # }}} @@ -1420,19 +1425,22 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, % type(insn).__name__) if numpy_types: - # FIXME: Don't modify in-place - access_map.count_map = dict( - (MemAccess( - mtype=mem_access.mtype, - dtype=mem_access.dtype.numpy_dtype, - stride=mem_access.stride, - direction=mem_access.direction, - variable=mem_access.variable, - count_granularity=mem_access.count_granularity - ), count) - for mem_access, count in six.iteritems(access_map.count_map)) - - return access_map + return ToCountMap( + init_dict=dict( + (MemAccess( + mtype=mem_access.mtype, + dtype=mem_access.dtype.numpy_dtype, + stride=mem_access.stride, + direction=mem_access.direction, + variable=mem_access.variable, + count_granularity=mem_access.count_granularity + ) + , ct) + for mem_access, ct in six.iteritems(access_map.count_map)), + val_type=access_map.val_type + ) + else: + return access_map # }}}