From d50730415071cdb7c3f0c4c7cca141bd6f390ef7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sat, 23 May 2015 15:08:00 -0400 Subject: [PATCH] Add fallback to counting routines, minor cleanups --- loopy/statistics.py | 41 ++++++++++++++++++++++++++----- test/test_statistics.py | 53 +++++++++++++++++++++++++++++++++-------- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index a1e8c7648..10e5c4c8d 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -1,6 +1,4 @@ -from __future__ import division -from __future__ import absolute_import -import six +from __future__ import division, absolute_import __copyright__ = "Copyright (C) 2015 James Stevens" @@ -24,6 +22,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import six # noqa + import loopy as lp import warnings from islpy import dim_type @@ -218,6 +218,36 @@ class SubscriptCounter(CombineMapper): return 0 +def count(kernel, bset): + try: + return bset.card() + except AttributeError: + pass + + if not bset.is_box(): + from loopy.diagnostic import warn + warn(kernel, "count_overestimate", + "Barvinok wrappers are not installed. " + "Counting routines may overestimate the " + "number of integer points in your loop " + "domain.") + + result = None + + for i in range(bset.dim(isl.dim_type.set)): + dmax = bset.dim_max(i) + dmin = bset.dim_min(i) + + length = isl.PwQPolynomial.from_pw_aff(dmax - dmin + 1) + + if result is None: + result = length + else: + result = result * length + + return result + + # to evaluate poly: poly.eval_with_dict(dictionary) def get_op_poly(knl): from loopy.preprocess import preprocess_kernel, infer_unknown_types @@ -233,7 +263,7 @@ def get_op_poly(knl): inames_domain = knl.get_inames_domain(insn_inames) domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) ops = op_counter(insn.expression) - op_poly = op_poly + ops*domain.card() + op_poly = op_poly + ops*count(knl, domain) return op_poly @@ -245,6 +275,5 @@ def get_DRAM_access_poly(knl): # for now just counting subscripts insn_inames = knl.insn_inames(insn) inames_domain = knl.get_inames_domain(insn_inames) domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) - poly += subscript_counter(insn.expression) * domain.card() + poly += subscript_counter(insn.expression) * count(knl, domain) return poly - diff --git a/test/test_statistics.py b/test/test_statistics.py index e68700833..f3919db2c 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -1,4 +1,4 @@ -from __future__ import division +from __future__ import division, print_function __copyright__ = "Copyright (C) 2015 James Stevens" @@ -23,14 +23,14 @@ THE SOFTWARE. """ import sys -from pyopencl.tools import ( +from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) from loopy.statistics import * # noqa import numpy as np -def test_op_counter_basic(ctx_factory): +def test_op_counter_basic(): knl = lp.make_kernel( "[n,m,l] -> {[i,k,j]: 0<=i {[i,k,j]: 0<=i6 or k/2==l, g[i,k]*2, g[i,k]+h[i,k]/2) @@ -98,10 +98,10 @@ def test_op_counter_logic(ctx_factory): assert i32 == n*m -def test_op_counter_specialops(ctx_factory): +def test_op_counter_specialops(): knl = lp.make_kernel( - "[n,m,l] -> {[i,k,j]: 0<=i {[i,k,j]: 0<=i 1: exec(sys.argv[1]) -- GitLab