Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
pyopencl
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Andreas Klöckner
pyopencl
Commits
828d7c08
Commit
828d7c08
authored
2 years ago
by
Alexandru Fikl
Committed by
Andreas Klöckner
2 years ago
Browse files
Options
Downloads
Patches
Plain Diff
add types to scan kernels
parent
a1504e46
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pyopencl/scan.py
+83
-49
83 additions, 49 deletions
pyopencl/scan.py
with
83 additions
and
49 deletions
pyopencl/scan.py
+
83
−
49
View file @
828d7c08
...
@@ -22,19 +22,25 @@ limitations under the License.
...
@@ -22,19 +22,25 @@ limitations under the License.
Derived from code within the Thrust project, https://github.com/thrust/thrust/
Derived from code within the Thrust project, https://github.com/thrust/thrust/
"""
"""
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
pyopencl
as
cl
import
pyopencl
as
cl
import
pyopencl.array
# noqa
import
pyopencl.array
from
pyopencl.tools
import
(
dtype_to_ctype
,
bitlog2
,
from
pyopencl.tools
import
(
KernelTemplateBase
,
_process_code_for_macro
,
KernelTemplateBase
,
get_arg_list_scalar_arg_dtypes
,
DtypedArgument
,
bitlog2
,
context_dependent_memoize
,
context_dependent_memoize
,
dtype_to_ctype
,
get_arg_list_scalar_arg_dtypes
,
get_arg_offset_adjuster_code
,
_process_code_for_macro
,
_NumpyTypesKeyBuilder
,
_NumpyTypesKeyBuilder
,
get_arg_offset_adjuster_code
)
)
import
pyopencl._mymako
as
mako
import
pyopencl._mymako
as
mako
from
pyopencl._cluda
import
CLUDA_PREAMBLE
from
pyopencl._cluda
import
CLUDA_PREAMBLE
...
@@ -745,7 +751,7 @@ void ${name_prefix}_final_update(
...
@@ -745,7 +751,7 @@ void ${name_prefix}_final_update(
# {{{ helpers
# {{{ helpers
def
_round_down_to_power_of_2
(
val
)
:
def
_round_down_to_power_of_2
(
val
:
int
)
->
int
:
result
=
2
**
bitlog2
(
val
)
result
=
2
**
bitlog2
(
val
)
if
result
>
val
:
if
result
>
val
:
result
>>=
1
result
>>=
1
...
@@ -839,10 +845,11 @@ _IGNORED_WORDS = set("""
...
@@ -839,10 +845,11 @@ _IGNORED_WORDS = set("""
"""
.
split
())
"""
.
split
())
def
_make_template
(
s
):
def
_make_template
(
s
:
str
):
import
re
leftovers
=
set
()
leftovers
=
set
()
def
replace_id
(
match
)
:
def
replace_id
(
match
:
"
re.Match
"
)
->
str
:
# avoid name clashes with user code by adding 'psc_' prefix to
# avoid name clashes with user code by adding 'psc_' prefix to
# identifiers.
# identifiers.
...
@@ -850,30 +857,28 @@ def _make_template(s):
...
@@ -850,30 +857,28 @@ def _make_template(s):
if
word
in
_IGNORED_WORDS
:
if
word
in
_IGNORED_WORDS
:
return
word
return
word
elif
word
in
_PREFIX_WORDS
:
elif
word
in
_PREFIX_WORDS
:
return
"
psc_
"
+
word
return
f
"
psc_
{
word
}
"
else
:
else
:
leftovers
.
add
(
word
)
leftovers
.
add
(
word
)
return
word
return
word
import
re
s
=
re
.
sub
(
r
"
\b([a-zA-Z0-9_]+)\b
"
,
replace_id
,
s
)
s
=
re
.
sub
(
r
"
\b([a-zA-Z0-9_]+)\b
"
,
replace_id
,
s
)
if
leftovers
:
if
leftovers
:
from
warnings
import
warn
from
warnings
import
warn
warn
(
"
leftover words in identifier prefixing:
"
+
"
"
.
join
(
leftovers
))
warn
(
"
leftover words in identifier prefixing:
"
+
"
"
.
join
(
leftovers
))
return
mako
.
template
.
Template
(
s
,
strict_undefined
=
True
)
return
mako
.
template
.
Template
(
s
,
strict_undefined
=
True
)
# type: ignore
@dataclass
(
frozen
=
True
)
@dataclass
(
frozen
=
True
)
class
_GeneratedScanKernelInfo
:
class
_GeneratedScanKernelInfo
:
scan_src
:
str
scan_src
:
str
kernel_name
:
str
kernel_name
:
str
scalar_arg_dtypes
:
List
[
"
np.dtype
"
]
scalar_arg_dtypes
:
List
[
Optional
[
np
.
dtype
]
]
wg_size
:
int
wg_size
:
int
k_group_size
:
int
k_group_size
:
int
def
build
(
self
,
context
,
options
)
:
def
build
(
self
,
context
:
cl
.
Context
,
options
:
Any
)
->
"
_BuiltScanKernelInfo
"
:
program
=
cl
.
Program
(
context
,
self
.
scan_src
).
build
(
options
)
program
=
cl
.
Program
(
context
,
self
.
scan_src
).
build
(
options
)
kernel
=
getattr
(
program
,
self
.
kernel_name
)
kernel
=
getattr
(
program
,
self
.
kernel_name
)
kernel
.
set_scalar_arg_dtypes
(
self
.
scalar_arg_dtypes
)
kernel
.
set_scalar_arg_dtypes
(
self
.
scalar_arg_dtypes
)
...
@@ -894,10 +899,12 @@ class _BuiltScanKernelInfo:
...
@@ -894,10 +899,12 @@ class _BuiltScanKernelInfo:
class
_GeneratedFinalUpdateKernelInfo
:
class
_GeneratedFinalUpdateKernelInfo
:
source
:
str
source
:
str
kernel_name
:
str
kernel_name
:
str
scalar_arg_dtypes
:
List
[
"
np.dtype
"
]
scalar_arg_dtypes
:
List
[
Optional
[
np
.
dtype
]
]
update_wg_size
:
int
update_wg_size
:
int
def
build
(
self
,
context
,
options
):
def
build
(
self
,
context
:
cl
.
Context
,
options
:
Any
)
->
"
_BuiltFinalUpdateKernelInfo
"
:
program
=
cl
.
Program
(
context
,
self
.
source
).
build
(
options
)
program
=
cl
.
Program
(
context
,
self
.
source
).
build
(
options
)
kernel
=
getattr
(
program
,
self
.
kernel_name
)
kernel
=
getattr
(
program
,
self
.
kernel_name
)
kernel
.
set_scalar_arg_dtypes
(
self
.
scalar_arg_dtypes
)
kernel
.
set_scalar_arg_dtypes
(
self
.
scalar_arg_dtypes
)
...
@@ -916,14 +923,25 @@ class ScanPerformanceWarning(UserWarning):
...
@@ -916,14 +923,25 @@ class ScanPerformanceWarning(UserWarning):
pass
pass
class
_GenericScanKernelBase
:
class
_GenericScanKernelBase
(
ABC
)
:
# {{{ constructor, argument processing
# {{{ constructor, argument processing
def
__init__
(
self
,
ctx
,
dtype
,
def
__init__
(
arguments
,
input_expr
,
scan_expr
,
neutral
,
output_statement
,
self
,
is_segment_start_expr
=
None
,
input_fetch_exprs
=
None
,
ctx
:
cl
.
Context
,
index_dtype
=
np
.
int32
,
dtype
:
Any
,
name_prefix
=
"
scan
"
,
options
=
None
,
preamble
=
""
,
devices
=
None
):
arguments
:
Union
[
str
,
List
[
DtypedArgument
]],
input_expr
:
str
,
scan_expr
:
str
,
neutral
:
Optional
[
str
],
output_statement
:
str
,
is_segment_start_expr
:
Optional
[
str
]
=
None
,
input_fetch_exprs
:
Optional
[
List
[
Tuple
[
str
,
str
,
int
]]]
=
None
,
index_dtype
:
Any
=
np
.
int32
,
name_prefix
:
str
=
"
scan
"
,
options
:
Any
=
None
,
preamble
:
str
=
""
,
devices
:
Optional
[
cl
.
Device
]
=
None
)
->
None
:
"""
"""
:arg ctx: a :class:`pyopencl.Context` within which the code
:arg ctx: a :class:`pyopencl.Context` within which the code
for this scan kernel will be generated.
for this scan kernel will be generated.
...
@@ -1114,8 +1132,9 @@ class _GenericScanKernelBase:
...
@@ -1114,8 +1132,9 @@ class _GenericScanKernelBase:
# }}}
# }}}
def
finish_setup
(
self
):
@abstractmethod
raise
NotImplementedError
def
finish_setup
(
self
)
->
None
:
pass
generic_scan_kernel_cache
=
WriteOncePersistentDict
(
generic_scan_kernel_cache
=
WriteOncePersistentDict
(
...
@@ -1139,10 +1158,9 @@ class GenericScanKernel(_GenericScanKernelBase):
...
@@ -1139,10 +1158,9 @@ class GenericScanKernel(_GenericScanKernelBase):
a = cl.array.arange(queue, 10000, dtype=np.int32)
a = cl.array.arange(queue, 10000, dtype=np.int32)
knl(a, queue=queue)
knl(a, queue=queue)
"""
"""
def
finish_setup
(
self
):
def
finish_setup
(
self
)
->
None
:
# Before generating the kernel, see if it's cached.
# Before generating the kernel, see if it's cached.
from
pyopencl.cache
import
get_device_cache_id
from
pyopencl.cache
import
get_device_cache_id
devices_key
=
tuple
(
get_device_cache_id
(
device
)
devices_key
=
tuple
(
get_device_cache_id
(
device
)
...
@@ -1188,7 +1206,7 @@ class GenericScanKernel(_GenericScanKernelBase):
...
@@ -1188,7 +1206,7 @@ class GenericScanKernel(_GenericScanKernelBase):
self
.
context
,
self
.
options
)
self
.
context
,
self
.
options
)
del
self
.
final_update_gen_info
del
self
.
final_update_gen_info
def
_finish_setup_impl
(
self
):
def
_finish_setup_impl
(
self
)
->
None
:
# {{{ find usable workgroup/k-group size, build first-level scan
# {{{ find usable workgroup/k-group size, build first-level scan
trip_count
=
0
trip_count
=
0
...
@@ -1296,7 +1314,7 @@ class GenericScanKernel(_GenericScanKernelBase):
...
@@ -1296,7 +1314,7 @@ class GenericScanKernel(_GenericScanKernelBase):
second_level_arguments
=
self
.
parsed_args
+
[
second_level_arguments
=
self
.
parsed_args
+
[
VectorArg
(
self
.
dtype
,
"
interval_sums
"
)]
VectorArg
(
self
.
dtype
,
"
interval_sums
"
)]
second_level_build_kwargs
=
{}
second_level_build_kwargs
:
Dict
[
str
,
Optional
[
str
]]
=
{}
if
self
.
is_segmented
:
if
self
.
is_segmented
:
second_level_arguments
.
append
(
second_level_arguments
.
append
(
VectorArg
(
self
.
index_dtype
,
VectorArg
(
self
.
index_dtype
,
...
@@ -1360,12 +1378,14 @@ class GenericScanKernel(_GenericScanKernelBase):
...
@@ -1360,12 +1378,14 @@ class GenericScanKernel(_GenericScanKernelBase):
# {{{ scan kernel build/properties
# {{{ scan kernel build/properties
def
get_local_mem_use
(
self
,
k_group_size
,
wg_size
,
use_bank_conflict_avoidance
):
def
get_local_mem_use
(
self
,
k_group_size
:
int
,
wg_size
:
int
,
use_bank_conflict_avoidance
:
bool
)
->
int
:
arg_dtypes
=
{}
arg_dtypes
=
{}
for
arg
in
self
.
parsed_args
:
for
arg
in
self
.
parsed_args
:
arg_dtypes
[
arg
.
name
]
=
arg
.
dtype
arg_dtypes
[
arg
.
name
]
=
arg
.
dtype
fetch_expr_offsets
=
{}
fetch_expr_offsets
:
Dict
[
str
,
Set
]
=
{}
for
_name
,
arg_name
,
ife_offset
in
self
.
input_fetch_exprs
:
for
_name
,
arg_name
,
ife_offset
in
self
.
input_fetch_exprs
:
fetch_expr_offsets
.
setdefault
(
arg_name
,
set
()).
add
(
ife_offset
)
fetch_expr_offsets
.
setdefault
(
arg_name
,
set
()).
add
(
ife_offset
)
...
@@ -1388,10 +1408,17 @@ class GenericScanKernel(_GenericScanKernelBase):
...
@@ -1388,10 +1408,17 @@ class GenericScanKernel(_GenericScanKernelBase):
for
arg_name
,
ife_offsets
in
list
(
fetch_expr_offsets
.
items
())
for
arg_name
,
ife_offsets
in
list
(
fetch_expr_offsets
.
items
())
if
-
1
in
ife_offsets
or
len
(
ife_offsets
)
>
1
))
if
-
1
in
ife_offsets
or
len
(
ife_offsets
)
>
1
))
def
generate_scan_kernel
(
self
,
max_wg_size
,
arguments
,
input_expr
,
def
generate_scan_kernel
(
is_segment_start_expr
,
input_fetch_exprs
,
is_first_level
,
self
,
store_segment_start_flags
,
k_group_size
,
max_wg_size
:
int
,
use_bank_conflict_avoidance
):
arguments
:
List
[
DtypedArgument
],
input_expr
:
str
,
is_segment_start_expr
:
Optional
[
str
],
input_fetch_exprs
:
List
[
Tuple
[
str
,
str
,
int
]],
is_first_level
:
bool
,
store_segment_start_flags
:
bool
,
k_group_size
:
int
,
use_bank_conflict_avoidance
:
bool
)
->
_GeneratedScanKernelInfo
:
scalar_arg_dtypes
=
get_arg_list_scalar_arg_dtypes
(
arguments
)
scalar_arg_dtypes
=
get_arg_list_scalar_arg_dtypes
(
arguments
)
# Empirically found on Nv hardware: no need to be bigger than this size
# Empirically found on Nv hardware: no need to be bigger than this size
...
@@ -1437,7 +1464,7 @@ class GenericScanKernel(_GenericScanKernelBase):
...
@@ -1437,7 +1464,7 @@ class GenericScanKernel(_GenericScanKernelBase):
# }}}
# }}}
def
__call__
(
self
,
*
args
,
**
kwargs
)
:
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
cl
.
Event
:
# {{{ argument processing
# {{{ argument processing
allocator
=
kwargs
.
get
(
"
allocator
"
)
allocator
=
kwargs
.
get
(
"
allocator
"
)
...
@@ -1451,8 +1478,8 @@ class GenericScanKernel(_GenericScanKernelBase):
...
@@ -1451,8 +1478,8 @@ class GenericScanKernel(_GenericScanKernelBase):
wait_for
=
list
(
wait_for
)
wait_for
=
list
(
wait_for
)
if
len
(
args
)
!=
len
(
self
.
parsed_args
):
if
len
(
args
)
!=
len
(
self
.
parsed_args
):
raise
TypeError
(
"
expected %d arguments, got %d
"
%
raise
TypeError
(
(
len
(
self
.
parsed_args
)
,
len
(
args
)
)
)
f
"
expected
{
len
(
self
.
parsed_args
)
}
arguments, got
{
len
(
args
)
}
"
)
first_array
=
args
[
self
.
first_array_idx
]
first_array
=
args
[
self
.
first_array_idx
]
allocator
=
allocator
or
first_array
.
allocator
allocator
=
allocator
or
first_array
.
allocator
...
@@ -1631,7 +1658,7 @@ void ${name_prefix}_debug_scan(
...
@@ -1631,7 +1658,7 @@ void ${name_prefix}_debug_scan(
class
GenericDebugScanKernel
(
_GenericScanKernelBase
):
class
GenericDebugScanKernel
(
_GenericScanKernelBase
):
def
finish_setup
(
self
):
def
finish_setup
(
self
)
->
None
:
scan_tpl
=
_make_template
(
DEBUG_SCAN_TEMPLATE
)
scan_tpl
=
_make_template
(
DEBUG_SCAN_TEMPLATE
)
scan_src
=
str
(
scan_tpl
.
render
(
scan_src
=
str
(
scan_tpl
.
render
(
output_statement
=
self
.
output_statement
,
output_statement
=
self
.
output_statement
,
...
@@ -1645,15 +1672,14 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
...
@@ -1645,15 +1672,14 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
**
self
.
code_variables
))
**
self
.
code_variables
))
scan_prg
=
cl
.
Program
(
self
.
context
,
scan_src
).
build
(
self
.
options
)
scan_prg
=
cl
.
Program
(
self
.
context
,
scan_src
).
build
(
self
.
options
)
self
.
kernel
=
getattr
(
self
.
kernel
=
getattr
(
scan_prg
,
f
"
{
self
.
name_prefix
}
_debug_scan
"
)
scan_prg
,
self
.
name_prefix
+
"
_debug_scan
"
)
scalar_arg_dtypes
=
(
scalar_arg_dtypes
=
(
[
None
]
[
None
]
+
get_arg_list_scalar_arg_dtypes
(
self
.
parsed_args
)
+
get_arg_list_scalar_arg_dtypes
(
self
.
parsed_args
)
+
[
self
.
index_dtype
])
+
[
self
.
index_dtype
])
self
.
kernel
.
set_scalar_arg_dtypes
(
scalar_arg_dtypes
)
self
.
kernel
.
set_scalar_arg_dtypes
(
scalar_arg_dtypes
)
def
__call__
(
self
,
*
args
,
**
kwargs
)
:
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
cl
.
Event
:
# {{{ argument processing
# {{{ argument processing
allocator
=
kwargs
.
get
(
"
allocator
"
)
allocator
=
kwargs
.
get
(
"
allocator
"
)
...
@@ -1668,8 +1694,8 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
...
@@ -1668,8 +1694,8 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
wait_for
=
list
(
wait_for
)
wait_for
=
list
(
wait_for
)
if
len
(
args
)
!=
len
(
self
.
parsed_args
):
if
len
(
args
)
!=
len
(
self
.
parsed_args
):
raise
TypeError
(
"
expected %d arguments, got %d
"
%
raise
TypeError
(
(
len
(
self
.
parsed_args
)
,
len
(
args
)
)
)
f
"
expected
{
len
(
self
.
parsed_args
)
}
arguments, got
{
len
(
args
)
}
"
)
first_array
=
args
[
self
.
first_array_idx
]
first_array
=
args
[
self
.
first_array_idx
]
allocator
=
allocator
or
first_array
.
allocator
allocator
=
allocator
or
first_array
.
allocator
...
@@ -1763,15 +1789,23 @@ class ExclusiveScanKernel(_LegacyScanKernelBase):
...
@@ -1763,15 +1789,23 @@ class ExclusiveScanKernel(_LegacyScanKernelBase):
# {{{ template
# {{{ template
class
ScanTemplate
(
KernelTemplateBase
):
class
ScanTemplate
(
KernelTemplateBase
):
def
__init__
(
self
,
def
__init__
(
arguments
,
input_expr
,
scan_expr
,
neutral
,
output_statement
,
self
,
is_segment_start_expr
=
None
,
input_fetch_exprs
=
None
,
arguments
:
Union
[
str
,
List
[
DtypedArgument
]],
name_prefix
=
"
scan
"
,
preamble
=
""
,
template_processor
=
None
):
input_expr
:
str
,
scan_expr
:
str
,
neutral
:
Optional
[
str
],
output_statement
:
str
,
is_segment_start_expr
:
Optional
[
str
]
=
None
,
input_fetch_exprs
:
Optional
[
List
[
Tuple
[
str
,
str
,
int
]]]
=
None
,
name_prefix
:
str
=
"
scan
"
,
preamble
:
str
=
""
,
template_processor
:
Any
=
None
)
->
None
:
super
().
__init__
(
template_processor
=
template_processor
)
if
input_fetch_exprs
is
None
:
if
input_fetch_exprs
is
None
:
input_fetch_exprs
=
[]
input_fetch_exprs
=
[]
KernelTemplateBase
.
__init__
(
self
,
template_processor
=
template_processor
)
self
.
arguments
=
arguments
self
.
arguments
=
arguments
self
.
input_expr
=
input_expr
self
.
input_expr
=
input_expr
self
.
scan_expr
=
scan_expr
self
.
scan_expr
=
scan_expr
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment