From c6445465d65a7eedf41ecbd04ad5c92c090a6b94 Mon Sep 17 00:00:00 2001 From: jdsteve2 Date: Mon, 2 Apr 2018 18:04:47 -0500 Subject: [PATCH] function get_iname_strides() eliminating repeated code --- loopy/statistics.py | 107 +++++++++++++++++--------------------------- 1 file changed, 42 insertions(+), 65 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index 1fe55111c..c4f8c9e26 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -930,9 +930,13 @@ class GlobalMemAccessCounter(MemAccessCounter): elif isinstance(tag, GroupIndexTag): gid_to_iname[tag.axis] = iname - # create lid_strides dict (strides are coefficents in flattened index) - # i.e., we want {0:A, 1:B, 2:C, ...} where A, B, & C - # come from flattened index [... + C*lid2 + B*lid1 + A*lid0] + # create lid_strides and gid_strides dicts + + # strides are coefficients in flattened index, i.e., we want + # lid_strides = {0:l0, 1:l1, 2:l2, ...} and + # gid_strides = {0:g0, 1:g1, 2:g2, ...}, + # where l0, l1, l2, g0, g1, and g2 come from flattened index + # [... + g2*gid2 + g1*gid1 + g0*gid0 + ... 
+ l2*lid2 + l1*lid1 + l0*lid0] from loopy.symbolic import CoefficientCollector from loopy.kernel.array import FixedStrideArrayDimTag @@ -940,69 +944,42 @@ class GlobalMemAccessCounter(MemAccessCounter): from loopy.symbolic import simplify_using_aff from loopy.diagnostic import ExpressionNotAffineError - lid_strides = {} - for ltag, iname in six.iteritems(lid_to_iname): - ltag_stride = 0 - # check coefficient of this lid for each axis - for idx, axis_tag in zip(index, array.dim_tags): - - try: - coeffs = CoefficientCollector()( - simplify_using_aff(self.knl, idx)) - except ExpressionNotAffineError: - ltag_stride = None - break - - # check if idx contains this lid - try: - coeff_lid = coeffs[Variable(lid_to_iname[ltag])] - except KeyError: - # idx does not contain this lid - continue - - # found coefficient of this lid - # now determine stride - if isinstance(axis_tag, FixedStrideArrayDimTag): - stride = axis_tag.stride - else: - continue - - ltag_stride += stride*coeff_lid - lid_strides[ltag] = ltag_stride - - # create gid_strides dict (strides are coefficents in flattened index) - # i.e., we want {0:A, 1:B, 2:C, ...} where A, B, & C - # come from flattened index [... 
+ C*gid2 + B*gid1 + A*gid0] - - gid_strides = {} - for gtag, iname in six.iteritems(gid_to_iname): - gtag_stride = 0 - # check coefficient of this gid for each axis - for idx, axis_tag in zip(index, array.dim_tags): - - try: - coeffs = CoefficientCollector()( - simplify_using_aff(self.knl, idx)) - except ExpressionNotAffineError: - gtag_stride = None - break - - # check if idx contains this gid - try: - coeff_gid = coeffs[Variable(gid_to_iname[gtag])] - except KeyError: - # idx does not contain this gid - continue - - # found coefficient of this gid - # now determine stride - if isinstance(axis_tag, FixedStrideArrayDimTag): - stride = axis_tag.stride - else: - continue + def get_iname_strides(tag_to_iname_dict): + tag_to_stride_dict = {} + for tag, iname in six.iteritems(tag_to_iname_dict): + total_iname_stride = 0 + # find total stride of this iname for each axis + for idx, axis_tag in zip(index, array.dim_tags): + # collect index coefficients + try: + coeffs = CoefficientCollector()( + simplify_using_aff(self.knl, idx)) + except ExpressionNotAffineError: + total_iname_stride = None + break + + # check if idx contains this iname + try: + coeff = coeffs[Variable(tag_to_iname_dict[tag])] + except KeyError: + # idx does not contain this iname + continue + + # found coefficient of this iname + # now determine stride + if isinstance(axis_tag, FixedStrideArrayDimTag): + axis_tag_stride = axis_tag.stride + else: + continue + + total_iname_stride += axis_tag_stride*coeff + + tag_to_stride_dict[tag] = total_iname_stride + + return tag_to_stride_dict - gtag_stride += stride*coeff_gid - gid_strides[gtag] = gtag_stride + lid_strides = get_iname_strides(lid_to_iname) + gid_strides = get_iname_strides(gid_to_iname) count_granularity = CountGranularity.WORKITEM if ( 0 in lid_strides and lid_strides[0] != 0 -- GitLab