diff --git a/loopy/statistics.py b/loopy/statistics.py index a1e8c764857dcaf786f582696e32f22ef71c78e9..10e5c4c8df960ea232dbbe6f569bd77f440b1081 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -1,6 +1,4 @@ -from __future__ import division -from __future__ import absolute_import -import six +from __future__ import division, absolute_import __copyright__ = "Copyright (C) 2015 James Stevens" @@ -24,6 +22,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import six # noqa + import loopy as lp import warnings from islpy import dim_type @@ -218,6 +218,36 @@ class SubscriptCounter(CombineMapper): return 0 +def count(kernel, bset): + try: + return bset.card() + except AttributeError: + pass + + if not bset.is_box(): + from loopy.diagnostic import warn + warn(kernel, "count_overestimate", + "Barvinok wrappers are not installed. " + "Counting routines may overestimate the " + "number of integer points in your loop " + "domain.") + + result = None + + for i in range(bset.dim(isl.dim_type.set)): + dmax = bset.dim_max(i) + dmin = bset.dim_min(i) + + length = isl.PwQPolynomial.from_pw_aff(dmax - dmin + 1) + + if result is None: + result = length + else: + result = result * length + + return result + + # to evaluate poly: poly.eval_with_dict(dictionary) def get_op_poly(knl): from loopy.preprocess import preprocess_kernel, infer_unknown_types @@ -233,7 +263,7 @@ def get_op_poly(knl): inames_domain = knl.get_inames_domain(insn_inames) domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) ops = op_counter(insn.expression) - op_poly = op_poly + ops*domain.card() + op_poly = op_poly + ops*count(knl, domain) return op_poly @@ -245,6 +275,5 @@ def get_DRAM_access_poly(knl): # for now just counting subscripts insn_inames = knl.insn_inames(insn) inames_domain = knl.get_inames_domain(insn_inames) domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) - poly += 
subscript_counter(insn.expression) * domain.card() + poly += subscript_counter(insn.expression) * count(knl, domain) return poly - diff --git a/test/test_statistics.py b/test/test_statistics.py index e68700833b4a4b41bef5051b030aed92206622c6..f3919db2c91032c86c7e68a81ca62b64c8ff329b 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -1,4 +1,4 @@ -from __future__ import division +from __future__ import division, print_function __copyright__ = "Copyright (C) 2015 James Stevens" @@ -23,14 +23,14 @@ THE SOFTWARE. """ import sys -from pyopencl.tools import ( +from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) from loopy.statistics import * # noqa import numpy as np -def test_op_counter_basic(ctx_factory): +def test_op_counter_basic(): knl = lp.make_kernel( "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", @@ -56,7 +56,7 @@ def test_op_counter_basic(ctx_factory): assert i32 == n*m -def test_op_counter_reduction(ctx_factory): +def test_op_counter_reduction(): knl = lp.make_kernel( "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", @@ -74,10 +74,10 @@ def test_op_counter_reduction(ctx_factory): assert f32 == 2*n*m*l -def test_op_counter_logic(ctx_factory): +def test_op_counter_logic(): knl = lp.make_kernel( - "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", + "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", [ """ e[i,k] = if(not(k<l-2) and k>6 or k/2==l, g[i,k]*2, g[i,k]+h[i,k]/2) @@ -98,10 +98,10 @@ def test_op_counter_logic(ctx_factory): assert i32 == n*m -def test_op_counter_specialops(ctx_factory): +def test_op_counter_specialops(): knl = lp.make_kernel( - "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", + "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", [ """ c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0) @@ -124,10 +124,10 @@ def test_op_counter_specialops(ctx_factory): assert i32 == n*m -def test_op_counter_bitwise(ctx_factory): +def test_op_counter_bitwise(): knl = lp.make_kernel( - "[n,m,l] -> {[i,k,j]: 0<=i<n 
and 0<=k<m and 0<=j<l}", + "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", [ """ c[i, j, k] = (a[i,j,k] | 1) + (b[i,j,k] & 1) @@ -143,10 +143,43 @@ def test_op_counter_bitwise(ctx_factory): m = 256 l = 128 i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l}) + print(poly.dict[np.dtype(np.int32)]) not_there = poly[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l}) assert i32 == 3*n*m+n*m*l assert not_there == 0 + +def test_op_counter_triangular_domain(): + + knl = lp.make_kernel( + "{[i,j]: 0<=i<n and 0<=j<m and i<j}", + """ + a[i, j] = b[i,j] * 2 + """, + name="bitwise", assumptions="n,m >= 1") + + knl = lp.add_and_infer_dtypes(knl, + dict(b=np.float64)) + + expect_fallback = False + import islpy as isl + try: + isl.BasicSet.card + except AttributeError: + expect_fallback = True + else: + expect_fallback = False + + poly = get_op_poly(knl)[np.dtype(np.float64)] + value_dict = dict(m=13, n=200) + flops = poly.eval_with_dict(value_dict) + + if expect_fallback: + assert flops == 144 + else: + assert flops == 78 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])