Skip to content
Snippets Groups Projects
Commit d5073041 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Add fallback to counting routines, minor cleanups

parent 5aac0721
No related branches found
No related tags found
No related merge requests found
from __future__ import division
from __future__ import absolute_import
import six
from __future__ import division, absolute_import
__copyright__ = "Copyright (C) 2015 James Stevens"
......@@ -24,6 +22,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import six # noqa
import loopy as lp
import warnings
from islpy import dim_type
......@@ -218,6 +218,36 @@ class SubscriptCounter(CombineMapper):
return 0
def count(kernel, bset):
try:
return bset.card()
except AttributeError:
pass
if not bset.is_box():
from loopy.diagnostic import warn
warn(kernel, "count_overestimate",
"Barvinok wrappers are not installed. "
"Counting routines may overestimate the "
"number of integer points in your loop "
"domain.")
result = None
for i in range(bset.dim(isl.dim_type.set)):
dmax = bset.dim_max(i)
dmin = bset.dim_min(i)
length = isl.PwQPolynomial.from_pw_aff(dmax - dmin + 1)
if result is None:
result = length
else:
result = result * length
return result
# to evaluate poly: poly.eval_with_dict(dictionary)
def get_op_poly(knl):
from loopy.preprocess import preprocess_kernel, infer_unknown_types
......@@ -233,7 +263,7 @@ def get_op_poly(knl):
inames_domain = knl.get_inames_domain(insn_inames)
domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
ops = op_counter(insn.expression)
op_poly = op_poly + ops*domain.card()
op_poly = op_poly + ops*count(knl, domain)
return op_poly
......@@ -245,6 +275,5 @@ def get_DRAM_access_poly(knl): # for now just counting subscripts
insn_inames = knl.insn_inames(insn)
inames_domain = knl.get_inames_domain(insn_inames)
domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
poly += subscript_counter(insn.expression) * domain.card()
poly += subscript_counter(insn.expression) * count(knl, domain)
return poly
from __future__ import division
from __future__ import division, print_function
__copyright__ = "Copyright (C) 2015 James Stevens"
......@@ -23,14 +23,14 @@ THE SOFTWARE.
"""
import sys
from pyopencl.tools import (
from pyopencl.tools import ( # noqa
pytest_generate_tests_for_pyopencl
as pytest_generate_tests)
from loopy.statistics import * # noqa
import numpy as np
def test_op_counter_basic(ctx_factory):
def test_op_counter_basic():
knl = lp.make_kernel(
"[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
......@@ -56,7 +56,7 @@ def test_op_counter_basic(ctx_factory):
assert i32 == n*m
def test_op_counter_reduction(ctx_factory):
def test_op_counter_reduction():
knl = lp.make_kernel(
"{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
......@@ -74,10 +74,10 @@ def test_op_counter_reduction(ctx_factory):
assert f32 == 2*n*m*l
def test_op_counter_logic(ctx_factory):
def test_op_counter_logic():
knl = lp.make_kernel(
"[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
"{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
[
"""
e[i,k] = if(not(k<l-2) and k>6 or k/2==l, g[i,k]*2, g[i,k]+h[i,k]/2)
......@@ -98,10 +98,10 @@ def test_op_counter_logic(ctx_factory):
assert i32 == n*m
def test_op_counter_specialops(ctx_factory):
def test_op_counter_specialops():
knl = lp.make_kernel(
"[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
"{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
[
"""
c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0)
......@@ -124,10 +124,10 @@ def test_op_counter_specialops(ctx_factory):
assert i32 == n*m
def test_op_counter_bitwise(ctx_factory):
def test_op_counter_bitwise():
knl = lp.make_kernel(
"[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
"{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
[
"""
c[i, j, k] = (a[i,j,k] | 1) + (b[i,j,k] & 1)
......@@ -143,10 +143,43 @@ def test_op_counter_bitwise(ctx_factory):
m = 256
l = 128
i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
print(poly.dict[np.dtype(np.int32)])
not_there = poly[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
assert i32 == 3*n*m+n*m*l
assert not_there == 0
def test_op_counter_triangular_domain():
knl = lp.make_kernel(
"{[i,j]: 0<=i<n and 0<=j<m and i<j}",
"""
a[i, j] = b[i,j] * 2
""",
name="bitwise", assumptions="n,m >= 1")
knl = lp.add_and_infer_dtypes(knl,
dict(b=np.float64))
expect_fallback = False
import islpy as isl
try:
isl.BasicSet.carod
except AttributeError:
expect_fallback = True
else:
expect_fallback = False
poly = get_op_poly(knl)[np.dtype(np.float64)]
value_dict = dict(m=13, n=200)
flops = poly.eval_with_dict(value_dict)
if expect_fallback:
assert flops == 144
else:
assert flops == 78
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment