Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
L
loopy
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Andreas Klöckner
loopy
Commits
06196fa5
Commit
06196fa5
authored
9 years ago
by
Machine Owner
Browse files
Options
Downloads
Patches
Plain Diff
removed hardcoded datatypes from TypeToOpCountMap, updated tests
parent
8ea48395
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
loopy/statistics.py
+50
-63
50 additions, 63 deletions
loopy/statistics.py
test/test_statistics.py
+74
-49
74 additions, 49 deletions
test/test_statistics.py
with
124 additions
and
112 deletions
loopy/statistics.py
+
50
−
63
View file @
06196fa5
...
@@ -24,7 +24,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
...
@@ -24,7 +24,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
THE SOFTWARE.
"""
"""
import
numpy
as
np
import
loopy
as
lp
import
loopy
as
lp
import
warnings
import
warnings
from
islpy
import
dim_type
from
islpy
import
dim_type
...
@@ -32,46 +31,42 @@ import islpy._isl as isl
...
@@ -32,46 +31,42 @@ import islpy._isl as isl
from
pymbolic.mapper
import
CombineMapper
from
pymbolic.mapper
import
CombineMapper
class
Type
dPolyDict
:
class
Type
ToOpCountMap
:
def
__init__
(
self
,
i32
=
0
,
f32
=
0
,
f64
=
0
):
def
__init__
(
self
):
self
.
poly_dict
=
{
self
.
dict
=
{}
np
.
dtype
(
np
.
int32
):
i32
,
np
.
dtype
(
np
.
float32
):
f32
,
np
.
dtype
(
np
.
float64
):
f64
}
def
__add__
(
self
,
TPD
):
def
__add__
(
self
,
other
):
return
TypedPolyDict
(
result
=
TypeToOpCountMap
()
self
[
np
.
dtype
(
np
.
int32
)]
+
TPD
[
np
.
dtype
(
np
.
int32
)],
result
.
dict
=
dict
(
self
.
dict
.
items
()
+
other
.
dict
.
items
()
self
[
np
.
dtype
(
np
.
float32
)]
+
TPD
[
np
.
dtype
(
np
.
float32
)],
+
[(
k
,
self
.
dict
[
k
]
+
other
.
dict
[
k
])
self
[
np
.
dtype
(
np
.
float64
)]
+
TPD
[
np
.
dtype
(
np
.
float64
)])
for
k
in
set
(
self
.
dict
)
&
set
(
other
.
dict
)])
return
result
def
__radd__
(
self
,
other
):
def
__radd__
(
self
,
other
):
if
(
other
!=
0
):
if
(
other
!=
0
):
print
"
ERROR TRYING TO ADD TPD TO NON-ZERO NON-TPD
"
# TODO
message
=
"
TypeToOpCountMap: Attempted to add TypeToOpCountMap to
"
+
\
str
(
type
(
other
))
+
"
"
+
str
(
other
)
+
"
. TypeToOpCountMap
"
+
\
"
may only be added to 0 and other TypeToOpCountMap objects.
"
raise
ValueError
(
message
)
return
return
return
self
return
self
def
__mul__
(
self
,
other
):
def
__mul__
(
self
,
other
):
if
isinstance
(
other
,
isl
.
PwQPolynomial
):
if
isinstance
(
other
,
isl
.
PwQPolynomial
):
re
turn
TypedPolyDict
(
re
sult
=
TypeToOpCountMap
()
self
[
np
.
dtype
(
np
.
int32
)]
*
other
,
for
index
in
self
.
dict
.
keys
():
self
[
np
.
dtype
(
np
.
float32
)
]
*
other
,
result
.
dict
[
index
]
=
self
.
dict
[
index
]
*
other
self
[
np
.
dtype
(
np
.
float64
)]
*
other
)
return
result
else
:
else
:
# TODO
message
=
"
TypeToOpCountMap: Attempted to multiply TypeToOpCountMap by
"
+
\
print
"
ERROR: Cannot multiply TypedPolyDict by type
"
,
type
(
other
)
str
(
type
(
other
))
+
"
"
+
str
(
other
)
+
"
.
"
raise
ValueError
(
message
)
__rmul__
=
__mul__
__rmul__
=
__mul__
def
__getitem__
(
self
,
index
):
return
self
.
poly_dict
[
index
]
def
__setitem__
(
self
,
index
,
value
):
self
.
poly_dict
[
index
]
=
value
def
__str__
(
self
):
def
__str__
(
self
):
return
str
(
self
.
poly_
dict
)
return
str
(
self
.
dict
)
class
ExpressionOpCounter
(
CombineMapper
):
class
ExpressionOpCounter
(
CombineMapper
):
...
@@ -85,13 +80,13 @@ class ExpressionOpCounter(CombineMapper):
...
@@ -85,13 +80,13 @@ class ExpressionOpCounter(CombineMapper):
return
sum
(
values
)
return
sum
(
values
)
def
map_constant
(
self
,
expr
):
def
map_constant
(
self
,
expr
):
return
Type
dPolyDict
(
0
,
0
,
0
)
return
Type
ToOpCountMap
(
)
def
map_tagged_variable
(
self
,
expr
):
def
map_tagged_variable
(
self
,
expr
):
return
Type
dPolyDict
(
0
,
0
,
0
)
return
Type
ToOpCountMap
(
)
def
map_variable
(
self
,
expr
):
# implemented in FlopCounter
def
map_variable
(
self
,
expr
):
return
Type
dPolyDict
(
0
,
0
,
0
)
return
Type
ToOpCountMap
(
)
#def map_wildcard(self, expr):
#def map_wildcard(self, expr):
# return 0,0
# return 0,0
...
@@ -101,7 +96,7 @@ class ExpressionOpCounter(CombineMapper):
...
@@ -101,7 +96,7 @@ class ExpressionOpCounter(CombineMapper):
def
map_call
(
self
,
expr
):
def
map_call
(
self
,
expr
):
# implemented in CombineMapper (functions in opencl spec)
# implemented in CombineMapper (functions in opencl spec)
return
Type
dPolyDict
(
0
,
0
,
0
)
return
Type
ToOpCountMap
(
)
# def map_call_with_kwargs(self, expr): # implemented in CombineMapper
# def map_call_with_kwargs(self, expr): # implemented in CombineMapper
...
@@ -110,32 +105,32 @@ class ExpressionOpCounter(CombineMapper):
...
@@ -110,32 +105,32 @@ class ExpressionOpCounter(CombineMapper):
# def map_lookup(self, expr): # implemented in CombineMapper
# def map_lookup(self, expr): # implemented in CombineMapper
def
map_sum
(
self
,
expr
):
# implemented in FlopCounter
def
map_sum
(
self
,
expr
):
TPD
=
TypedPolyDict
(
0
,
0
,
0
)
op_count_map
=
TypeToOpCountMap
(
)
TPD
[
self
.
type_inf
(
expr
)]
=
len
(
expr
.
children
)
-
1
op_count_map
.
dict
[
self
.
type_inf
(
expr
)]
=
len
(
expr
.
children
)
-
1
if
expr
.
children
:
if
expr
.
children
:
return
TPD
+
sum
(
self
.
rec
(
child
)
for
child
in
expr
.
children
)
return
op_count_map
+
sum
(
self
.
rec
(
child
)
for
child
in
expr
.
children
)
else
:
else
:
return
Type
dPolyDict
(
0
,
0
,
0
)
return
Type
ToOpCountMap
(
)
map_product
=
map_sum
map_product
=
map_sum
def
map_quotient
(
self
,
expr
,
*
args
):
def
map_quotient
(
self
,
expr
,
*
args
):
TPD
=
TypedPolyDict
(
0
,
0
,
0
)
op_count_map
=
TypeToOpCountMap
(
)
TPD
[
self
.
type_inf
(
expr
)]
=
1
op_count_map
.
dict
[
self
.
type_inf
(
expr
)]
=
1
return
TPD
+
self
.
rec
(
expr
.
numerator
)
+
self
.
rec
(
expr
.
denominator
)
return
op_count_map
+
self
.
rec
(
expr
.
numerator
)
+
self
.
rec
(
expr
.
denominator
)
map_floor_div
=
map_quotient
map_floor_div
=
map_quotient
def
map_remainder
(
self
,
expr
):
# implemented in CombineMapper
def
map_remainder
(
self
,
expr
):
# implemented in CombineMapper
TPD
=
TypedPolyDict
(
0
,
0
,
0
)
op_count_map
=
TypeToOpCountMap
(
)
TPD
[
self
.
type_inf
(
expr
)]
=
1
op_count_map
.
dict
[
self
.
type_inf
(
expr
)]
=
1
return
TPD
+
self
.
rec
(
expr
.
numerator
)
+
self
.
rec
(
expr
.
denominator
)
return
op_count_map
+
self
.
rec
(
expr
.
numerator
)
+
self
.
rec
(
expr
.
denominator
)
def
map_power
(
self
,
expr
):
# implemented in FlopCounter
def
map_power
(
self
,
expr
):
TPD
=
TypedPolyDict
(
0
,
0
,
0
)
op_count_map
=
TypeToOpCountMap
(
)
TPD
[
self
.
type_inf
(
expr
)]
=
1
op_count_map
.
dict
[
self
.
type_inf
(
expr
)]
=
1
return
TPD
+
self
.
rec
(
expr
.
base
)
+
self
.
rec
(
expr
.
exponent
)
return
op_count_map
+
self
.
rec
(
expr
.
base
)
+
self
.
rec
(
expr
.
exponent
)
def
map_left_shift
(
self
,
expr
):
# implemented in CombineMapper
def
map_left_shift
(
self
,
expr
):
# implemented in CombineMapper
return
self
.
rec
(
expr
.
shiftee
)
+
self
.
rec
(
expr
.
shift
)
# TODO test
return
self
.
rec
(
expr
.
shiftee
)
+
self
.
rec
(
expr
.
shift
)
# TODO test
...
@@ -169,15 +164,10 @@ class ExpressionOpCounter(CombineMapper):
...
@@ -169,15 +164,10 @@ class ExpressionOpCounter(CombineMapper):
def
map_if
(
self
,
expr
):
# implemented in CombineMapper, recurses
def
map_if
(
self
,
expr
):
# implemented in CombineMapper, recurses
warnings
.
warn
(
"
Counting operations as sum of if-statement branches.
"
)
warnings
.
warn
(
"
Counting operations as sum of if-statement branches.
"
)
# return self.rec(expr.condition) + max(
# self.rec(expr.then), self.rec(expr.else_))
return
self
.
rec
(
expr
.
condition
)
+
self
.
rec
(
expr
.
then
)
+
self
.
rec
(
expr
.
else_
)
return
self
.
rec
(
expr
.
condition
)
+
self
.
rec
(
expr
.
then
)
+
self
.
rec
(
expr
.
else_
)
def
map_if_positive
(
self
,
expr
):
# implemented in FlopCounter
def
map_if_positive
(
self
,
expr
):
# implemented in FlopCounter
warnings
.
warn
(
"
Counting operations as sum of if_pos-statement branches.
"
)
warnings
.
warn
(
"
Counting operations as sum of if_pos-statement branches.
"
)
# return self.rec(expr.criterion) + max(
# self.rec(expr.then),
# self.rec(expr.else_))
return
self
.
rec
(
expr
.
criterion
)
+
self
.
rec
(
expr
.
then
)
+
self
.
rec
(
expr
.
else_
)
return
self
.
rec
(
expr
.
criterion
)
+
self
.
rec
(
expr
.
then
)
+
self
.
rec
(
expr
.
else_
)
def
map_min
(
self
,
expr
):
def
map_min
(
self
,
expr
):
...
@@ -187,23 +177,23 @@ class ExpressionOpCounter(CombineMapper):
...
@@ -187,23 +177,23 @@ class ExpressionOpCounter(CombineMapper):
map_max
=
map_min
# implemented in CombineMapper, maps to map_sum; # TODO test
map_max
=
map_min
# implemented in CombineMapper, maps to map_sum; # TODO test
def
map_common_subexpression
(
self
,
expr
):
def
map_common_subexpression
(
self
,
expr
):
raise
NotImplementedError
(
"
OpCounter encountered common_subexpression,
\
raise
NotImplementedError
(
"
OpCounter encountered common_subexpression,
"
map_common_subexpression not implemented.
"
)
"
map_common_subexpression not implemented.
"
)
return
0
return
0
def
map_substitution
(
self
,
expr
):
def
map_substitution
(
self
,
expr
):
raise
NotImplementedError
(
"
OpCounter encountered substitution,
\
raise
NotImplementedError
(
"
OpCounter encountered substitution,
"
map_substitution not implemented.
"
)
"
map_substitution not implemented.
"
)
return
0
return
0
def
map_derivative
(
self
,
expr
):
def
map_derivative
(
self
,
expr
):
raise
NotImplementedError
(
"
OpCounter encountered derivative,
\
raise
NotImplementedError
(
"
OpCounter encountered derivative,
"
map_derivative not implemented.
"
)
"
map_derivative not implemented.
"
)
return
0
return
0
def
map_slice
(
self
,
expr
):
def
map_slice
(
self
,
expr
):
raise
NotImplementedError
(
"
OpCounter encountered slice,
\
raise
NotImplementedError
(
"
OpCounter encountered slice,
"
map_slice not implemented.
"
)
"
map_slice not implemented.
"
)
return
0
return
0
...
@@ -226,7 +216,6 @@ class SubscriptCounter(CombineMapper):
...
@@ -226,7 +216,6 @@ class SubscriptCounter(CombineMapper):
if
tv
.
is_local
:
if
tv
.
is_local
:
# It's shared memory
# It's shared memory
pass
pass
return
1
+
self
.
rec
(
expr
.
index
)
return
1
+
self
.
rec
(
expr
.
index
)
def
map_constant
(
self
,
expr
):
def
map_constant
(
self
,
expr
):
...
@@ -238,12 +227,9 @@ class SubscriptCounter(CombineMapper):
...
@@ -238,12 +227,9 @@ class SubscriptCounter(CombineMapper):
# to evaluate poly: poly.eval_with_dict(dictionary)
# to evaluate poly: poly.eval_with_dict(dictionary)
def
get_op_poly
(
knl
):
def
get_op_poly
(
knl
):
from
loopy.preprocess
import
preprocess_kernel
,
infer_unknown_types
from
loopy.preprocess
import
preprocess_kernel
,
infer_unknown_types
knl
=
infer_unknown_types
(
knl
,
expect_completion
=
True
)
knl
=
infer_unknown_types
(
knl
,
expect_completion
=
True
)
knl
=
preprocess_kernel
(
knl
)
knl
=
preprocess_kernel
(
knl
)
#print knl
op_poly
=
0
op_poly
=
0
op_counter
=
ExpressionOpCounter
(
knl
)
op_counter
=
ExpressionOpCounter
(
knl
)
...
@@ -259,6 +245,7 @@ def get_op_poly(knl):
...
@@ -259,6 +245,7 @@ def get_op_poly(knl):
def
get_DRAM_access_poly
(
knl
):
# for now just counting subscripts
def
get_DRAM_access_poly
(
knl
):
# for now just counting subscripts
raise
NotImplementedError
(
"
get_DRAM_access_poly not yet implemented.
"
)
poly
=
0
poly
=
0
subscript_counter
=
SubscriptCounter
(
knl
)
subscript_counter
=
SubscriptCounter
(
knl
)
for
insn
in
knl
.
instructions
:
for
insn
in
knl
.
instructions
:
...
...
This diff is collapsed.
Click to expand it.
test/test_statistics.py
+
74
−
49
View file @
06196fa5
...
@@ -26,7 +26,8 @@ import sys
...
@@ -26,7 +26,8 @@ import sys
from
pyopencl.tools
import
(
from
pyopencl.tools
import
(
pytest_generate_tests_for_pyopencl
pytest_generate_tests_for_pyopencl
as
pytest_generate_tests
)
as
pytest_generate_tests
)
from
loopy.statistics
import
*
from
loopy.statistics
import
*
# noqa
import
numpy
as
np
def
test_op_counter_basic
(
ctx_factory
):
def
test_op_counter_basic
(
ctx_factory
):
...
@@ -34,95 +35,119 @@ def test_op_counter_basic(ctx_factory):
...
@@ -34,95 +35,119 @@ def test_op_counter_basic(ctx_factory):
knl
=
lp
.
make_kernel
(
knl
=
lp
.
make_kernel
(
"
[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}
"
,
"
[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}
"
,
[
[
"""
"""
c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
e[i, k] = g[i,k]*h[i,k]
e[i, k] = g[i,k]*h[i,k
+1
]
"""
"""
],
],
name
=
"
weird
"
,
assumptions
=
"
n,m,l >= 1
"
)
name
=
"
weird
"
,
assumptions
=
"
n,m,l >= 1
"
)
knl
=
lp
.
add_and_infer_dtypes
(
knl
,
dict
(
a
=
np
.
float32
,
b
=
np
.
float32
,
g
=
np
.
float32
,
h
=
np
.
float32
))
knl
=
lp
.
add_and_infer_dtypes
(
knl
,
dict
(
a
=
np
.
float32
,
b
=
np
.
float32
,
g
=
np
.
float64
,
h
=
np
.
float64
))
poly
=
get_op_poly
(
knl
)
poly
=
get_op_poly
(
knl
)
n
=
512
n
=
512
m
=
256
m
=
256
l
=
128
l
=
128
flops
=
poly
.
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
f32
=
poly
.
dict
[
np
.
dtype
(
np
.
float32
)].
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
assert
flops
==
n
*
m
+
3
*
n
*
m
*
l
f64
=
poly
.
dict
[
np
.
dtype
(
np
.
float64
)].
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
i32
=
poly
.
dict
[
np
.
dtype
(
np
.
int32
)].
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
assert
f32
==
3
*
n
*
m
*
l
assert
f64
==
n
*
m
assert
i32
==
n
*
m
def
test_op_counter_reduction
(
ctx_factory
):
def
test_op_counter_reduction
(
ctx_factory
):
knl
=
lp
.
make_kernel
(
knl
=
lp
.
make_kernel
(
"
{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}
"
,
"
{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}
"
,
[
[
"
c[i, j] = sum(k, a[i, k]*b[k, j])
"
"
c[i, j] = sum(k, a[i, k]*b[k, j])
"
],
],
name
=
"
matmul
"
,
assumptions
=
"
n,m,l >= 1
"
)
name
=
"
matmul
"
,
assumptions
=
"
n,m,l >= 1
"
)
knl
=
lp
.
add_and_infer_dtypes
(
knl
,
dict
(
a
=
np
.
float32
,
b
=
np
.
float32
))
knl
=
lp
.
add_and_infer_dtypes
(
knl
,
dict
(
a
=
np
.
float32
,
b
=
np
.
float32
))
poly
=
get_op_poly
(
knl
)
poly
=
get_op_poly
(
knl
)
n
=
512
n
=
512
m
=
256
m
=
256
l
=
128
l
=
128
flops
=
poly
.
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
f32
=
poly
.
dict
[
np
.
dtype
(
np
.
float32
)].
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
assert
flops
==
2
*
n
*
m
*
l
assert
f32
==
2
*
n
*
m
*
l
def
test_op_counter_logic
(
ctx_factory
):
def
test_op_counter_logic
(
ctx_factory
):
knl
=
lp
.
make_kernel
(
knl
=
lp
.
make_kernel
(
"
[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}
"
,
"
[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}
"
,
[
[
"""
"""
e[i,k] = if(not(k
<
l-2) and k
> l+
6 or k/2
==
l, g[i,k]*
h[i,k]
, g[i,k]+h[i,k]/2
.0
)
e[i,k] = if(not(k
<
l-2) and k
>
6 or k/2==l, g[i,k]*
2
, g[i,k]+h[i,k]/2)
"""
"""
],
],
name
=
"
logic
"
,
assumptions
=
"
n,m,l >= 1
"
)
name
=
"
logic
"
,
assumptions
=
"
n,m,l >= 1
"
)
knl
=
lp
.
add_and_infer_dtypes
(
knl
,
dict
(
g
=
np
.
float32
,
h
=
np
.
float
32
))
knl
=
lp
.
add_and_infer_dtypes
(
knl
,
dict
(
g
=
np
.
float32
,
h
=
np
.
float
64
))
poly
=
get_op_poly
(
knl
)
poly
=
get_op_poly
(
knl
)
n
=
512
n
=
512
m
=
256
m
=
256
l
=
128
l
=
128
flops
=
poly
.
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
f32
=
poly
.
dict
[
np
.
dtype
(
np
.
float32
)].
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
assert
flops
==
5
*
n
*
m
f64
=
poly
.
dict
[
np
.
dtype
(
np
.
float64
)].
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
i32
=
poly
.
dict
[
np
.
dtype
(
np
.
int32
)].
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
assert
f32
==
n
*
m
assert
f64
==
3
*
n
*
m
assert
i32
==
n
*
m
def
test_op_counter_remainder
(
ctx_factory
):
def
test_op_counter_specialops
(
ctx_factory
):
knl
=
lp
.
make_kernel
(
knl
=
lp
.
make_kernel
(
"
[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}
"
,
"
[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}
"
,
[
[
"""
"""
c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0)
c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0)
"""
e[i, k] = (1+g[i,k])**(1+h[i,k+1])
"""
],
],
name
=
"
logic
"
,
assumptions
=
"
n,m,l >= 1
"
)
name
=
"
specialops
"
,
assumptions
=
"
n,m,l >= 1
"
)
knl
=
lp
.
add_and_infer_dtypes
(
knl
,
dict
(
a
=
np
.
float32
,
b
=
np
.
float32
))
knl
=
lp
.
add_and_infer_dtypes
(
knl
,
dict
(
a
=
np
.
float32
,
b
=
np
.
float32
,
g
=
np
.
float64
,
h
=
np
.
float64
))
poly
=
get_op_poly
(
knl
)
poly
=
get_op_poly
(
knl
)
n
=
512
n
=
512
m
=
256
m
=
256
l
=
128
l
=
128
flops
=
poly
.
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
f32
=
poly
.
dict
[
np
.
dtype
(
np
.
float32
)].
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
assert
flops
==
4
*
n
*
m
*
l
f64
=
poly
.
dict
[
np
.
dtype
(
np
.
float64
)].
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
i32
=
poly
.
dict
[
np
.
dtype
(
np
.
int32
)].
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
assert
f32
==
4
*
n
*
m
*
l
assert
f64
==
3
*
n
*
m
assert
i32
==
n
*
m
def
test_op_counter_
power
(
ctx_factory
):
def
test_op_counter_
bitwise
(
ctx_factory
):
knl
=
lp
.
make_kernel
(
knl
=
lp
.
make_kernel
(
"
[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}
"
,
"
[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}
"
,
[
[
"""
"""
c[i, j, k] = a[i,j,k]
**3.0
c[i, j, k] =
(
a[i,j,k]
| 1) + (b[i,j,k] & 1)
e[i, k] = (
1+
g[i,k]
)**(1+
h[i,k+1])
e[i, k] = (g[i,k]
^ k)*(~
h[i,k+1])
"""
"""
],
],
name
=
"
weird
"
,
assumptions
=
"
n,m,l >= 1
"
)
name
=
"
bitwise
"
,
assumptions
=
"
n,m,l >= 1
"
)
knl
=
lp
.
add_and_infer_dtypes
(
knl
,
dict
(
a
=
np
.
float32
,
g
=
np
.
float32
,
h
=
np
.
float32
))
knl
=
lp
.
add_and_infer_dtypes
(
knl
,
dict
(
a
=
np
.
float32
,
b
=
np
.
float32
,
g
=
np
.
float64
,
h
=
np
.
float64
))
poly
=
get_op_poly
(
knl
)
poly
=
get_op_poly
(
knl
)
n
=
512
n
=
512
m
=
256
m
=
256
l
=
128
l
=
128
flops
=
poly
.
eval_with_dict
({
'
n
'
:
n
,
'
m
'
:
m
,
'
l
'
:
l
})
'''
assert
flops
==
4
*
n
*
m
+
n
*
m
*
l
f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({
'
n
'
: n,
'
m
'
: m,
'
l
'
: l})
f64 = poly.dict[np.dtype(np.float64)].eval_with_dict({
'
n
'
: n,
'
m
'
: m,
'
l
'
: l})
i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({
'
n
'
: n,
'
m
'
: m,
'
l
'
: l})
'''
# TODO figure out how these operations should be counted
if
__name__
==
"
__main__
"
:
if
__name__
==
"
__main__
"
:
if
len
(
sys
.
argv
)
>
1
:
if
len
(
sys
.
argv
)
>
1
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment