Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
A
arraycontext
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Andreas Klöckner
arraycontext
Commits
91a94fc3
Commit
91a94fc3
authored
3 years ago
by
Kaushik Kulkarni
Committed by
Andreas Klöckner
3 years ago
Browse files
Options
Downloads
Patches
Plain Diff
make the frozen type of PytatoPyOpenCLArrayContext to be TaggableCLArrays
parent
c8427b92
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
arraycontext/impl/pytato/__init__.py
+87
-21
87 additions, 21 deletions
arraycontext/impl/pytato/__init__.py
arraycontext/impl/pytato/compile.py
+46
-10
46 additions, 10 deletions
arraycontext/impl/pytato/compile.py
arraycontext/impl/pytato/utils.py
+1
-0
1 addition, 0 deletions
arraycontext/impl/pytato/utils.py
with
134 additions
and
31 deletions
arraycontext/impl/pytato/__init__.py
+
87
−
21
View file @
91a94fc3
...
@@ -79,6 +79,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
...
@@ -79,6 +79,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
self
.
allocator
=
allocator
self
.
allocator
=
allocator
self
.
array_types
=
(
pt
.
Array
,
)
self
.
array_types
=
(
pt
.
Array
,
)
self
.
_freeze_prg_cache
=
{}
self
.
_freeze_prg_cache
=
{}
self
.
_dag_transform_cache
=
{}
# unused, but necessary to keep the context alive
# unused, but necessary to keep the context alive
self
.
context
=
self
.
queue
.
context
self
.
context
=
self
.
queue
.
context
...
@@ -113,24 +114,56 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
...
@@ -113,24 +114,56 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
return
cl_array
.
get
(
queue
=
self
.
queue
)
return
cl_array
.
get
(
queue
=
self
.
queue
)
def
call_loopy
(
self
,
program
,
**
kwargs
):
def
call_loopy
(
self
,
program
,
**
kwargs
):
import
pyopencl.array
as
cla
from
pytato.scalar_expr
import
SCALAR_CLASSES
from
pytato.loopy
import
call_loopy
from
pytato.loopy
import
call_loopy
from
arraycontext.impl.pyopencl.taggable_cl_array
import
TaggableCLArray
entrypoint
=
program
.
default_entrypoint
.
name
entrypoint
=
program
.
default_entrypoint
.
name
# thaw frozen arrays
# {{{ preprocess args
kwargs
=
{
kw
:
(
self
.
thaw
(
arg
)
if
isinstance
(
arg
,
cla
.
Array
)
else
arg
)
for
kw
,
arg
in
kwargs
.
items
()}
processed_kwargs
=
{}
for
kw
,
arg
in
sorted
(
kwargs
.
items
()):
if
isinstance
(
arg
,
self
.
array_types
+
SCALAR_CLASSES
):
pass
elif
isinstance
(
arg
,
TaggableCLArray
):
arg
=
self
.
thaw
(
arg
)
else
:
raise
ValueError
(
f
"
call_loopy argument
'
{
kw
}
'
expected to be an
"
"
instance of
'
pytato.Array
'
,
'
Number
'
or
"
f
"'
TaggableCLArray
'
, got
'
{
type
(
arg
)
}
'"
)
processed_kwargs
[
kw
]
=
arg
# }}}
return
call_loopy
(
program
,
kwargs
,
entrypoint
)
return
call_loopy
(
program
,
processed_
kwargs
,
entrypoint
)
def
freeze
(
self
,
array
):
def
freeze
(
self
,
array
):
import
pytato
as
pt
import
pytato
as
pt
import
pyopencl.array
as
cla
import
pyopencl.array
as
cla
import
loopy
as
lp
import
loopy
as
lp
from
arraycontext.impl.pytato.utils
import
(
_normalize_pt_expr
,
get_cl_axes_from_pt_axes
)
from
arraycontext.impl.pyopencl.taggable_cl_array
import
(
to_tagged_cl_array
,
TaggableCLArray
)
if
isinstance
(
array
,
cla
.
Array
):
if
isinstance
(
array
,
TaggableCL
Array
):
return
array
.
with_queue
(
None
)
return
array
.
with_queue
(
None
)
if
isinstance
(
array
,
cla
.
Array
):
from
warnings
import
warn
warn
(
"
Freezing pyopencl.array.Array will be deprecated in 2023.
"
"
Use `to_tagged_cl_array` to convert the array to
"
"
TaggableCLArray
"
,
DeprecationWarning
,
stacklevel
=
2
)
return
to_tagged_cl_array
(
array
.
with_queue
(
None
),
axes
=
None
,
tags
=
frozenset
())
if
isinstance
(
array
,
pt
.
DataWrapper
):
# trivial freeze.
return
to_tagged_cl_array
(
array
.
data
.
with_queue
(
None
),
axes
=
get_cl_axes_from_pt_axes
(
array
.
axes
),
tags
=
array
.
tags
)
if
not
isinstance
(
array
,
pt
.
Array
):
if
not
isinstance
(
array
,
pt
.
Array
):
raise
TypeError
(
"
PytatoPyOpenCLArrayContext.freeze invoked with
"
raise
TypeError
(
"
PytatoPyOpenCLArrayContext.freeze invoked with
"
f
"
non-pytato array of type
'
{
type
(
array
)
}
'"
)
f
"
non-pytato array of type
'
{
type
(
array
)
}
'"
)
...
@@ -138,14 +171,16 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
...
@@ -138,14 +171,16 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
# {{{ early exit for 0-sized arrays
# {{{ early exit for 0-sized arrays
if
array
.
size
==
0
:
if
array
.
size
==
0
:
return
cla
.
empty
(
self
.
queue
.
context
,
return
to_tagged_cl_array
(
shape
=
array
.
shape
,
cla
.
empty
(
self
.
queue
.
context
,
dtype
=
array
.
dtype
,
shape
=
array
.
shape
,
allocator
=
self
.
allocator
)
dtype
=
array
.
dtype
,
allocator
=
self
.
allocator
),
get_cl_axes_from_pt_axes
(
array
.
axes
),
array
.
tags
)
# }}}
# }}}
from
arraycontext.impl.pytato.utils
import
_normalize_pt_expr
pt_dict_of_named_arrays
=
pt
.
make_dict_of_named_arrays
(
pt_dict_of_named_arrays
=
pt
.
make_dict_of_named_arrays
(
{
"
_actx_out
"
:
array
})
{
"
_actx_out
"
:
array
})
...
@@ -155,7 +190,13 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
...
@@ -155,7 +190,13 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
try
:
try
:
pt_prg
=
self
.
_freeze_prg_cache
[
normalized_expr
]
pt_prg
=
self
.
_freeze_prg_cache
[
normalized_expr
]
except
KeyError
:
except
KeyError
:
pt_prg
=
pt
.
generate_loopy
(
self
.
transform_dag
(
normalized_expr
),
if
normalized_expr
in
self
.
_dag_transform_cache
:
transformed_dag
=
self
.
_dag_transform_cache
[
normalized_expr
]
else
:
transformed_dag
=
self
.
transform_dag
(
normalized_expr
)
self
.
_dag_transform_cache
[
normalized_expr
]
=
transformed_dag
pt_prg
=
pt
.
generate_loopy
(
transformed_dag
,
options
=
lp
.
Options
(
return_dict
=
True
,
options
=
lp
.
Options
(
return_dict
=
True
,
no_numpy
=
True
),
no_numpy
=
True
),
cl_device
=
self
.
queue
.
device
)
cl_device
=
self
.
queue
.
device
)
...
@@ -166,17 +207,31 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
...
@@ -166,17 +207,31 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
evt
,
out_dict
=
pt_prg
(
self
.
queue
,
**
bound_arguments
)
evt
,
out_dict
=
pt_prg
(
self
.
queue
,
**
bound_arguments
)
evt
.
wait
()
evt
.
wait
()
return
out_dict
[
"
_actx_out
"
].
with_queue
(
None
)
return
to_tagged_cl_array
(
out_dict
[
"
_actx_out
"
].
with_queue
(
None
),
get_cl_axes_from_pt_axes
(
self
.
_dag_transform_cache
[
normalized_expr
][
"
_actx_out
"
].
expr
.
axes
),
array
.
tags
)
def
thaw
(
self
,
array
):
def
thaw
(
self
,
array
):
import
pytato
as
pt
import
pytato
as
pt
import
pyopencl.array
as
cla
from
.utils
import
get_pt_axes_from_cl_axes
from
arraycontext.impl.pyopencl.taggable_cl_array
import
(
TaggableCLArray
,
if
not
isinstance
(
array
,
cla
.
Array
):
to_tagged_cl_array
)
raise
TypeError
(
"
PytatoPyOpenCLArrayContext.thaw expects CL arrays, got
"
import
pyopencl.array
as
cl_array
f
"
{
type
(
array
)
}
"
)
if
isinstance
(
array
,
TaggableCLArray
):
return
pt
.
make_data_wrapper
(
array
.
with_queue
(
self
.
queue
))
pass
elif
isinstance
(
array
,
cl_array
.
Array
):
array
=
to_tagged_cl_array
(
array
,
axes
=
None
,
tags
=
frozenset
())
else
:
raise
TypeError
(
"
PytatoPyOpenCLArrayContext.thaw expects
"
"'
TaggableCLArray
'
or
'
cl.array.Array
'
got
"
f
"
{
type
(
array
)
}
.
"
)
return
pt
.
make_data_wrapper
(
array
.
with_queue
(
self
.
queue
),
axes
=
get_pt_axes_from_cl_axes
(
array
.
axes
),
tags
=
array
.
tags
)
# }}}
# }}}
...
@@ -219,12 +274,23 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
...
@@ -219,12 +274,23 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
def
einsum
(
self
,
spec
,
*
args
,
arg_names
=
None
,
tagged
=
()):
def
einsum
(
self
,
spec
,
*
args
,
arg_names
=
None
,
tagged
=
()):
import
pyopencl.array
as
cla
import
pyopencl.array
as
cla
import
pytato
as
pt
import
pytato
as
pt
from
arraycontext.impl.pyopencl.taggable_cl_array
import
(
TaggableCLArray
,
to_tagged_cl_array
)
if
arg_names
is
None
:
if
arg_names
is
None
:
arg_names
=
(
None
,)
*
len
(
args
)
arg_names
=
(
None
,)
*
len
(
args
)
def
preprocess_arg
(
name
,
arg
):
def
preprocess_arg
(
name
,
arg
):
if
isinstance
(
arg
,
cla
.
Array
):
if
isinstance
(
arg
,
TaggableCL
Array
):
ary
=
self
.
thaw
(
arg
)
ary
=
self
.
thaw
(
arg
)
elif
isinstance
(
arg
,
cla
.
Array
):
from
warnings
import
warn
warn
(
"
Passing pyopencl.array.Array to einsum will be
"
"
deprecated in 2023.
"
"
Use `to_tagged_cl_array` to convert the array to
"
"
TaggableCLArray.
"
,
DeprecationWarning
,
stacklevel
=
2
)
ary
=
self
.
thaw
(
to_tagged_cl_array
(
arg
,
axes
=
None
,
tags
=
frozenset
()))
else
:
else
:
assert
isinstance
(
arg
,
pt
.
Array
)
assert
isinstance
(
arg
,
pt
.
Array
)
ary
=
arg
ary
=
arg
...
...
This diff is collapsed.
Click to expand it.
arraycontext/impl/pytato/compile.py
+
46
−
10
View file @
91a94fc3
...
@@ -34,7 +34,7 @@ from arraycontext.container.traversal import rec_keyed_map_array_container
...
@@ -34,7 +34,7 @@ from arraycontext.container.traversal import rec_keyed_map_array_container
import
abc
import
abc
import
numpy
as
np
import
numpy
as
np
from
typing
import
Any
,
Callable
,
Tuple
,
Dict
,
Mapping
from
typing
import
Any
,
Callable
,
Tuple
,
Dict
,
Mapping
,
FrozenSet
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
pyrsistent
import
pmap
,
PMap
from
pyrsistent
import
pmap
,
PMap
...
@@ -169,7 +169,11 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
...
@@ -169,7 +169,11 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
elif
is_array_container_type
(
arg
.
__class__
):
elif
is_array_container_type
(
arg
.
__class__
):
def
_rec_to_placeholder
(
keys
,
ary
):
def
_rec_to_placeholder
(
keys
,
ary
):
name
=
arg_id_to_name
[(
kw
,)
+
keys
]
name
=
arg_id_to_name
[(
kw
,)
+
keys
]
return
pt
.
make_placeholder
(
name
,
ary
.
shape
,
ary
.
dtype
)
return
pt
.
make_placeholder
(
name
,
ary
.
shape
,
ary
.
dtype
,
axes
=
ary
.
axes
,
tags
=
ary
.
tags
)
return
rec_keyed_map_array_container
(
_rec_to_placeholder
,
arg
)
return
rec_keyed_map_array_container
(
_rec_to_placeholder
,
arg
)
else
:
else
:
...
@@ -204,6 +208,13 @@ class LazilyCompilingFunctionCaller:
...
@@ -204,6 +208,13 @@ class LazilyCompilingFunctionCaller:
with
ProcessLogger
(
logger
,
"
transform_dag
"
):
with
ProcessLogger
(
logger
,
"
transform_dag
"
):
pt_dict_of_named_arrays
=
self
.
actx
.
transform_dag
(
dict_of_named_arrays
)
pt_dict_of_named_arrays
=
self
.
actx
.
transform_dag
(
dict_of_named_arrays
)
name_in_program_to_tags
=
{
name
:
out
.
tags
for
name
,
out
in
pt_dict_of_named_arrays
.
_data
.
items
()}
name_in_program_to_axes
=
{
name
:
out
.
axes
for
name
,
out
in
pt_dict_of_named_arrays
.
_data
.
items
()}
with
ProcessLogger
(
logger
,
"
generate_loopy
"
):
with
ProcessLogger
(
logger
,
"
generate_loopy
"
):
pytato_program
=
pt
.
generate_loopy
(
pt_dict_of_named_arrays
,
pytato_program
=
pt
.
generate_loopy
(
pt_dict_of_named_arrays
,
options
=
lp
.
Options
(
options
=
lp
.
Options
(
...
@@ -225,7 +236,7 @@ class LazilyCompilingFunctionCaller:
...
@@ -225,7 +236,7 @@ class LazilyCompilingFunctionCaller:
.
actx
.
actx
.
transform_loopy_program
))
.
transform_loopy_program
))
return
pytato_program
return
pytato_program
,
name_in_program_to_tags
,
name_in_program_to_axes
def
_dag_to_compiled_func
(
self
,
ary_or_dict_of_named_arrays
,
def
_dag_to_compiled_func
(
self
,
ary_or_dict_of_named_arrays
,
input_id_to_name_in_program
,
output_id_to_name_in_program
,
input_id_to_name_in_program
,
output_id_to_name_in_program
,
...
@@ -234,18 +245,23 @@ class LazilyCompilingFunctionCaller:
...
@@ -234,18 +245,23 @@ class LazilyCompilingFunctionCaller:
output_id
=
"
_pt_out
"
output_id
=
"
_pt_out
"
dict_of_named_arrays
=
pt
.
make_dict_of_named_arrays
(
dict_of_named_arrays
=
pt
.
make_dict_of_named_arrays
(
{
output_id
:
ary_or_dict_of_named_arrays
})
{
output_id
:
ary_or_dict_of_named_arrays
})
pytato_program
=
self
.
_dag_to_transformed_loopy_prg
(
dict_of_named_arrays
)
pytato_program
,
name_in_program_to_tags
,
name_in_program_to_axes
=
(
self
.
_dag_to_transformed_loopy_prg
(
dict_of_named_arrays
))
return
CompiledFunctionReturningArray
(
return
CompiledFunctionReturningArray
(
self
.
actx
,
pytato_program
,
self
.
actx
,
pytato_program
,
input_id_to_name_in_program
=
input_id_to_name_in_program
,
input_id_to_name_in_program
=
input_id_to_name_in_program
,
output_name_in_program
=
output_id
)
output_tags
=
name_in_program_to_tags
[
output_id
],
output_axes
=
name_in_program_to_axes
[
output_id
],
output_name
=
output_id
)
elif
isinstance
(
ary_or_dict_of_named_arrays
,
pt
.
DictOfNamedArrays
):
elif
isinstance
(
ary_or_dict_of_named_arrays
,
pt
.
DictOfNamedArrays
):
pytato_program
=
self
.
_dag_to_transformed_loopy_prg
(
pytato_program
,
name_in_program_to_tags
,
name_in_program_to_axes
=
(
ary_or_dict_of_named_arrays
)
self
.
_dag_to_transformed_loopy_prg
(
ary_or_dict_of_named_arrays
)
)
return
CompiledFunctionReturningArrayContainer
(
return
CompiledFunctionReturningArrayContainer
(
self
.
actx
,
pytato_program
,
self
.
actx
,
pytato_program
,
input_id_to_name_in_program
=
input_id_to_name_in_program
,
input_id_to_name_in_program
=
input_id_to_name_in_program
,
output_id_to_name_in_program
=
output_id_to_name_in_program
,
output_id_to_name_in_program
=
output_id_to_name_in_program
,
name_in_program_to_tags
=
name_in_program_to_tags
,
name_in_program_to_axes
=
name_in_program_to_axes
,
output_template
=
output_template
)
output_template
=
output_template
)
else
:
else
:
raise
NotImplementedError
(
type
(
ary_or_dict_of_named_arrays
))
raise
NotImplementedError
(
type
(
ary_or_dict_of_named_arrays
))
...
@@ -312,6 +328,8 @@ class LazilyCompilingFunctionCaller:
...
@@ -312,6 +328,8 @@ class LazilyCompilingFunctionCaller:
def
_args_to_cl_buffers
(
actx
,
input_id_to_name_in_program
,
arg_id_to_arg
):
def
_args_to_cl_buffers
(
actx
,
input_id_to_name_in_program
,
arg_id_to_arg
):
from
arraycontext.impl.pyopencl.taggable_cl_array
import
TaggableCLArray
input_kwargs_for_loopy
=
{}
input_kwargs_for_loopy
=
{}
for
arg_id
,
arg
in
arg_id_to_arg
.
items
():
for
arg_id
,
arg
in
arg_id_to_arg
.
items
():
...
@@ -320,7 +338,7 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
...
@@ -320,7 +338,7 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
elif
isinstance
(
arg
,
pt
.
array
.
DataWrapper
):
elif
isinstance
(
arg
,
pt
.
array
.
DataWrapper
):
# got a Datwwrapper => simply gets its data
# got a Datwwrapper => simply gets its data
arg
=
arg
.
data
arg
=
arg
.
data
elif
isinstance
(
arg
,
cla
.
Array
):
elif
isinstance
(
arg
,
TaggableCL
Array
):
# got a frozen array => do nothing
# got a frozen array => do nothing
pass
pass
elif
isinstance
(
arg
,
pt
.
Array
):
elif
isinstance
(
arg
,
pt
.
Array
):
...
@@ -383,9 +401,14 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
...
@@ -383,9 +401,14 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
pytato_program
:
pt
.
target
.
BoundProgram
pytato_program
:
pt
.
target
.
BoundProgram
input_id_to_name_in_program
:
Mapping
[
Tuple
[
Any
,
...],
str
]
input_id_to_name_in_program
:
Mapping
[
Tuple
[
Any
,
...],
str
]
output_id_to_name_in_program
:
Mapping
[
Tuple
[
Any
,
...],
str
]
output_id_to_name_in_program
:
Mapping
[
Tuple
[
Any
,
...],
str
]
name_in_program_to_tags
:
Mapping
[
str
,
FrozenSet
[
Tag
]]
name_in_program_to_axes
:
Mapping
[
str
,
Tuple
[
pt
.
Axis
,
...]]
output_template
:
ArrayContainer
output_template
:
ArrayContainer
def
__call__
(
self
,
arg_id_to_arg
)
->
ArrayContainer
:
def
__call__
(
self
,
arg_id_to_arg
)
->
ArrayContainer
:
from
arraycontext.impl.pyopencl.taggable_cl_array
import
to_tagged_cl_array
from
.utils
import
get_cl_axes_from_pt_axes
input_kwargs_for_loopy
=
_args_to_cl_buffers
(
input_kwargs_for_loopy
=
_args_to_cl_buffers
(
self
.
actx
,
self
.
input_id_to_name_in_program
,
arg_id_to_arg
)
self
.
actx
,
self
.
input_id_to_name_in_program
,
arg_id_to_arg
)
...
@@ -399,7 +422,12 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
...
@@ -399,7 +422,12 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
evt
.
wait
()
evt
.
wait
()
def
to_output_template
(
keys
,
_
):
def
to_output_template
(
keys
,
_
):
return
self
.
actx
.
thaw
(
out_dict
[
self
.
output_id_to_name_in_program
[
keys
]])
name_in_program
=
self
.
output_id_to_name_in_program
[
keys
]
return
self
.
actx
.
thaw
(
to_tagged_cl_array
(
out_dict
[
name_in_program
],
axes
=
get_cl_axes_from_pt_axes
(
self
.
name_in_program_to_axes
[
name_in_program
]),
tags
=
self
.
name_in_program_to_tags
[
name_in_program
]))
return
rec_keyed_map_array_container
(
to_output_template
,
return
rec_keyed_map_array_container
(
to_output_template
,
self
.
output_template
)
self
.
output_template
)
...
@@ -415,9 +443,14 @@ class CompiledFunctionReturningArray(CompiledFunction):
...
@@ -415,9 +443,14 @@ class CompiledFunctionReturningArray(CompiledFunction):
actx
:
PytatoPyOpenCLArrayContext
actx
:
PytatoPyOpenCLArrayContext
pytato_program
:
pt
.
target
.
BoundProgram
pytato_program
:
pt
.
target
.
BoundProgram
input_id_to_name_in_program
:
Mapping
[
Tuple
[
Any
,
...],
str
]
input_id_to_name_in_program
:
Mapping
[
Tuple
[
Any
,
...],
str
]
output_tags
:
FrozenSet
[
Tag
]
output_axes
:
Tuple
[
pt
.
Axis
,
...]
output_name
:
str
output_name
:
str
def
__call__
(
self
,
arg_id_to_arg
)
->
ArrayContainer
:
def
__call__
(
self
,
arg_id_to_arg
)
->
ArrayContainer
:
from
arraycontext.impl.pyopencl.taggable_cl_array
import
to_tagged_cl_array
from
.utils
import
get_cl_axes_from_pt_axes
input_kwargs_for_loopy
=
_args_to_cl_buffers
(
input_kwargs_for_loopy
=
_args_to_cl_buffers
(
self
.
actx
,
self
.
input_id_to_name_in_program
,
arg_id_to_arg
)
self
.
actx
,
self
.
input_id_to_name_in_program
,
arg_id_to_arg
)
...
@@ -430,4 +463,7 @@ class CompiledFunctionReturningArray(CompiledFunction):
...
@@ -430,4 +463,7 @@ class CompiledFunctionReturningArray(CompiledFunction):
# running out of memory. This mitigates that risk a bit, for now.
# running out of memory. This mitigates that risk a bit, for now.
evt
.
wait
()
evt
.
wait
()
return
self
.
actx
.
thaw
(
out_dict
[
self
.
output_name
])
return
self
.
actx
.
thaw
(
to_tagged_cl_array
(
out_dict
[
self
.
output_name
],
axes
=
get_cl_axes_from_pt_axes
(
self
.
output_axes
),
tags
=
self
.
output_tags
))
This diff is collapsed.
Click to expand it.
arraycontext/impl/pytato/utils.py
+
1
−
0
View file @
91a94fc3
...
@@ -58,6 +58,7 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
...
@@ -58,6 +58,7 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
shape
=
tuple
(
self
.
rec
(
s
)
if
isinstance
(
s
,
Array
)
else
s
shape
=
tuple
(
self
.
rec
(
s
)
if
isinstance
(
s
,
Array
)
else
s
for
s
in
expr
.
shape
),
for
s
in
expr
.
shape
),
dtype
=
expr
.
dtype
,
dtype
=
expr
.
dtype
,
axes
=
expr
.
axes
,
tags
=
expr
.
tags
)
tags
=
expr
.
tags
)
def
map_size_param
(
self
,
expr
:
SizeParam
)
->
Array
:
def
map_size_param
(
self
,
expr
:
SizeParam
)
->
Array
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment