Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
L
loopy
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Andreas Klöckner
loopy
Commits
91be1b94
Commit
91be1b94
authored
12 years ago
by
Andreas Klöckner
Browse files
Options
Downloads
Patches
Plain Diff
Improve code generation for constants, 'f' vs no trailing 'f', integer vs non-integer.
parent
76de8414
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
MEMO
+3
-1
3 additions, 1 deletion
MEMO
loopy/codegen/bounds.py
+2
-2
2 additions, 2 deletions
loopy/codegen/bounds.py
loopy/codegen/expression.py
+183
-79
183 additions, 79 deletions
loopy/codegen/expression.py
loopy/codegen/instruction.py
+10
-3
10 additions, 3 deletions
loopy/codegen/instruction.py
with
198 additions
and
85 deletions
MEMO
+
3
−
1
View file @
91be1b94
...
@@ -65,7 +65,6 @@ To-do
...
@@ -65,7 +65,6 @@ To-do
- Scalar insn priority
- Scalar insn priority
- What to do about constants in codegen? (...f suffix, complex types)
- If finding a maximum proves troublesome, move parameters into the domain
- If finding a maximum proves troublesome, move parameters into the domain
...
@@ -123,6 +122,9 @@ Future ideas
...
@@ -123,6 +122,9 @@ Future ideas
Dealt with
Dealt with
^^^^^^^^^^
^^^^^^^^^^
- What to do about constants in codegen? (...f suffix, complex types)
-> dealt with by type contexts
- relating to Multi-Domain
- relating to Multi-Domain
- Make sure that variables that enter into loop bounds are only written
- Make sure that variables that enter into loop bounds are only written
exactly once. [DONE]
exactly once. [DONE]
...
...
This diff is collapsed.
Click to expand it.
loopy/codegen/bounds.py
+
2
−
2
View file @
91be1b94
...
@@ -173,7 +173,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
...
@@ -173,7 +173,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
from
pymbolic
import
var
from
pymbolic
import
var
rhs
+=
iname_coeff
*
var
(
iname
)
rhs
+=
iname_coeff
*
var
(
iname
)
end_conds
.
append
(
"
%s >= 0
"
%
end_conds
.
append
(
"
%s >= 0
"
%
ccm
(
cfm
(
rhs
)))
ccm
(
cfm
(
rhs
)
,
'
i
'
))
else
:
# iname_coeff > 0
else
:
# iname_coeff > 0
kind
,
bound
=
solve_constraint_for_bound
(
cns
,
iname
)
kind
,
bound
=
solve_constraint_for_bound
(
cns
,
iname
)
assert
kind
==
"
>=
"
assert
kind
==
"
>=
"
...
@@ -205,7 +205,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
...
@@ -205,7 +205,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
from
cgen
import
For
from
cgen
import
For
from
loopy.codegen
import
wrap_in
from
loopy.codegen
import
wrap_in
return
wrap_in
(
For
,
return
wrap_in
(
For
,
"
int %s = %s
"
%
(
iname
,
ccm
(
start_expr
)),
"
int %s = %s
"
%
(
iname
,
ccm
(
start_expr
,
'
i
'
)),
"
&&
"
.
join
(
end_conds
),
"
&&
"
.
join
(
end_conds
),
"
++%s
"
%
iname
,
"
++%s
"
%
iname
,
stmt
)
stmt
)
...
...
This diff is collapsed.
Click to expand it.
loopy/codegen/expression.py
+
183
−
79
View file @
91be1b94
...
@@ -2,8 +2,9 @@ from __future__ import division
...
@@ -2,8 +2,9 @@ from __future__ import division
import
numpy
as
np
import
numpy
as
np
from
pymbolic.mapper.c_code
import
CCodeMapper
as
CCodeMapper
from
pymbolic.mapper
import
RecursiveMapper
from
pymbolic.mapper.stringifier
import
PREC_NONE
from
pymbolic.mapper.stringifier
import
(
PREC_NONE
,
PREC_CALL
,
PREC_PRODUCT
,
PREC_POWER
)
from
pymbolic.mapper
import
CombineMapper
from
pymbolic.mapper
import
CombineMapper
# {{{ type inference
# {{{ type inference
...
@@ -57,7 +58,7 @@ class TypeInferenceMapper(CombineMapper):
...
@@ -57,7 +58,7 @@ class TypeInferenceMapper(CombineMapper):
if
isinstance
(
identifier
,
Variable
):
if
isinstance
(
identifier
,
Variable
):
identifier
=
identifier
.
name
identifier
=
identifier
.
name
arg_dtypes
=
tuple
(
self
.
rec
(
par
)
for
par
in
expr
.
parameters
)
arg_dtypes
=
tuple
(
self
.
rec
(
par
,
None
)
for
par
in
expr
.
parameters
)
mangle_result
=
self
.
kernel
.
mangle_function
(
identifier
,
arg_dtypes
)
mangle_result
=
self
.
kernel
.
mangle_function
(
identifier
,
arg_dtypes
)
if
mangle_result
is
not
None
:
if
mangle_result
is
not
None
:
...
@@ -118,7 +119,25 @@ def perform_cast(ccm, expr, expr_dtype, target_dtype):
...
@@ -118,7 +119,25 @@ def perform_cast(ccm, expr, expr_dtype, target_dtype):
# {{{ C code mapper
# {{{ C code mapper
class
LoopyCCodeMapper
(
CCodeMapper
):
# type_context may be:
# - 'i' for integer -
# - 'f' for single-precision floating point
# - 'd' for double-precision floating point
# or None for 'no known context'.
def
dtype_to_type_context
(
dtype
):
dtype
=
np
.
dtype
(
dtype
)
if
dtype
.
kind
==
'
i
'
:
return
'
i
'
if
dtype
in
[
np
.
float64
,
np
.
complex128
]:
return
'
d
'
if
dtype
in
[
np
.
float32
,
np
.
complex64
]:
return
'
f
'
return
None
class
LoopyCCodeMapper
(
RecursiveMapper
):
def
__init__
(
self
,
kernel
,
seen_dtypes
,
seen_functions
,
var_subst_map
=
{},
def
__init__
(
self
,
kernel
,
seen_dtypes
,
seen_functions
,
var_subst_map
=
{},
with_annotation
=
False
,
allow_complex
=
False
):
with_annotation
=
False
,
allow_complex
=
False
):
"""
"""
...
@@ -127,7 +146,6 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -127,7 +146,6 @@ class LoopyCCodeMapper(CCodeMapper):
functions that were encountered.
functions that were encountered.
"""
"""
CCodeMapper
.
__init__
(
self
)
self
.
kernel
=
kernel
self
.
kernel
=
kernel
self
.
seen_dtypes
=
seen_dtypes
self
.
seen_dtypes
=
seen_dtypes
self
.
seen_functions
=
seen_functions
self
.
seen_functions
=
seen_functions
...
@@ -138,6 +156,8 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -138,6 +156,8 @@ class LoopyCCodeMapper(CCodeMapper):
self
.
with_annotation
=
with_annotation
self
.
with_annotation
=
with_annotation
self
.
var_subst_map
=
var_subst_map
.
copy
()
self
.
var_subst_map
=
var_subst_map
.
copy
()
# {{{ copy helpers
def
copy
(
self
,
var_subst_map
=
None
):
def
copy
(
self
,
var_subst_map
=
None
):
if
var_subst_map
is
None
:
if
var_subst_map
is
None
:
var_subst_map
=
self
.
var_subst_map
var_subst_map
=
self
.
var_subst_map
...
@@ -146,11 +166,6 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -146,11 +166,6 @@ class LoopyCCodeMapper(CCodeMapper):
with_annotation
=
self
.
with_annotation
,
with_annotation
=
self
.
with_annotation
,
allow_complex
=
self
.
allow_complex
)
allow_complex
=
self
.
allow_complex
)
def
infer_type
(
self
,
expr
):
result
=
self
.
type_inf_mapper
(
expr
)
self
.
seen_dtypes
.
add
(
result
)
return
result
def
copy_and_assign
(
self
,
name
,
value
):
def
copy_and_assign
(
self
,
name
,
value
):
"""
Make a copy of self with variable *name* fixed to *value*.
"""
"""
Make a copy of self with variable *name* fixed to *value*.
"""
var_subst_map
=
self
.
var_subst_map
.
copy
()
var_subst_map
=
self
.
var_subst_map
.
copy
()
...
@@ -164,18 +179,41 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -164,18 +179,41 @@ class LoopyCCodeMapper(CCodeMapper):
var_subst_map
.
update
(
assignments
)
var_subst_map
.
update
(
assignments
)
return
self
.
copy
(
var_subst_map
=
var_subst_map
)
return
self
.
copy
(
var_subst_map
=
var_subst_map
)
def
map_common_subexpression
(
self
,
expr
,
prec
):
# }}}
# {{{ helpers
def
infer_type
(
self
,
expr
):
result
=
self
.
type_inf_mapper
(
expr
)
self
.
seen_dtypes
.
add
(
result
)
return
result
def
join_rec
(
self
,
joiner
,
iterable
,
prec
,
type_context
):
f
=
joiner
.
join
(
"
%s
"
for
i
in
iterable
)
return
f
%
tuple
(
self
.
rec
(
i
,
prec
,
type_context
)
for
i
in
iterable
)
def
parenthesize_if_needed
(
self
,
s
,
enclosing_prec
,
my_prec
):
if
enclosing_prec
>
my_prec
:
return
"
(%s)
"
%
s
else
:
return
s
# }}}
def
map_common_subexpression
(
self
,
expr
,
prec
,
type_context
):
raise
RuntimeError
(
"
common subexpression should have been eliminated upon
"
raise
RuntimeError
(
"
common subexpression should have been eliminated upon
"
"
entry to loopy
"
)
"
entry to loopy
"
)
def
map_variable
(
self
,
expr
,
prec
):
def
map_variable
(
self
,
expr
,
enclosing_prec
,
type_context
):
if
expr
.
name
in
self
.
var_subst_map
:
if
expr
.
name
in
self
.
var_subst_map
:
if
self
.
with_annotation
:
if
self
.
with_annotation
:
return
"
/* %s */ %s
"
%
(
return
"
/* %s */ %s
"
%
(
expr
.
name
,
expr
.
name
,
self
.
rec
(
self
.
var_subst_map
[
expr
.
name
],
prec
))
self
.
rec
(
self
.
var_subst_map
[
expr
.
name
],
enclosing_prec
,
type_context
))
else
:
else
:
return
str
(
self
.
rec
(
self
.
var_subst_map
[
expr
.
name
],
prec
))
return
str
(
self
.
rec
(
self
.
var_subst_map
[
expr
.
name
],
enclosing_prec
,
type_context
))
elif
expr
.
name
in
self
.
kernel
.
arg_dict
:
elif
expr
.
name
in
self
.
kernel
.
arg_dict
:
arg
=
self
.
kernel
.
arg_dict
[
expr
.
name
]
arg
=
self
.
kernel
.
arg_dict
[
expr
.
name
]
from
loopy.kernel
import
_ShapedArg
from
loopy.kernel
import
_ShapedArg
...
@@ -188,15 +226,22 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -188,15 +226,22 @@ class LoopyCCodeMapper(CCodeMapper):
_
,
c_name
=
result
_
,
c_name
=
result
return
c_name
return
c_name
return
CCodeMapper
.
map_variable
(
self
,
expr
,
prec
)
return
expr
.
name
def
map_tagged_variable
(
self
,
expr
,
enclosing_prec
):
def
map_tagged_variable
(
self
,
expr
,
enclosing_prec
,
type_context
):
return
expr
.
name
return
expr
.
name
def
map_subscript
(
self
,
expr
,
enclosing_prec
):
def
map_subscript
(
self
,
expr
,
enclosing_prec
,
type_context
):
def
base_impl
(
expr
,
enclosing_prec
,
type_context
):
return
self
.
parenthesize_if_needed
(
"
%s[%s]
"
%
(
self
.
rec
(
expr
.
aggregate
,
PREC_CALL
,
type_context
),
self
.
rec
(
expr
.
index
,
PREC_NONE
,
'
i
'
)),
enclosing_prec
,
PREC_CALL
)
from
pymbolic.primitives
import
Variable
from
pymbolic.primitives
import
Variable
if
not
isinstance
(
expr
.
aggregate
,
Variable
):
if
not
isinstance
(
expr
.
aggregate
,
Variable
):
return
CCodeMapper
.
map_subscript
(
self
,
expr
,
enclosing_prec
)
return
base_impl
(
expr
,
enclosing_prec
,
type_context
)
if
expr
.
aggregate
.
name
in
self
.
kernel
.
arg_dict
:
if
expr
.
aggregate
.
name
in
self
.
kernel
.
arg_dict
:
arg
=
self
.
kernel
.
arg_dict
[
expr
.
aggregate
.
name
]
arg
=
self
.
kernel
.
arg_dict
[
expr
.
aggregate
.
name
]
...
@@ -207,7 +252,7 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -207,7 +252,7 @@ class LoopyCCodeMapper(CCodeMapper):
base_access
=
(
"
read_imagef(%s, loopy_sampler, (float%d)(%s))
"
base_access
=
(
"
read_imagef(%s, loopy_sampler, (float%d)(%s))
"
%
(
arg
.
name
,
arg
.
dimensions
,
%
(
arg
.
name
,
arg
.
dimensions
,
"
,
"
.
join
(
self
.
rec
(
idx
,
PREC_NONE
)
"
,
"
.
join
(
self
.
rec
(
idx
,
PREC_NONE
,
'
i
'
)
for
idx
in
expr
.
index
[::
-
1
])))
for
idx
in
expr
.
index
[::
-
1
])))
if
arg
.
dtype
==
np
.
float32
:
if
arg
.
dtype
==
np
.
float32
:
...
@@ -239,10 +284,11 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -239,10 +284,11 @@ class LoopyCCodeMapper(CCodeMapper):
return
"
*
"
+
expr
.
aggregate
.
name
return
"
*
"
+
expr
.
aggregate
.
name
from
pymbolic.primitives
import
Subscript
from
pymbolic.primitives
import
Subscript
return
CCodeMapper
.
map_subscript
(
self
,
return
base_impl
(
Subscript
(
expr
.
aggregate
,
arg
.
offset
+
sum
(
Subscript
(
expr
.
aggregate
,
arg
.
offset
+
sum
(
stride
*
expr_i
for
stride
,
expr_i
in
zip
(
stride
*
expr_i
for
stride
,
expr_i
in
zip
(
ary_strides
,
index_expr
))),
enclosing_prec
)
ary_strides
,
index_expr
))),
enclosing_prec
,
type_context
)
elif
expr
.
aggregate
.
name
in
self
.
kernel
.
temporary_variables
:
elif
expr
.
aggregate
.
name
in
self
.
kernel
.
temporary_variables
:
...
@@ -252,53 +298,68 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -252,53 +298,68 @@ class LoopyCCodeMapper(CCodeMapper):
else
:
else
:
index
=
(
expr
.
index
,)
index
=
(
expr
.
index
,)
return
(
temp_var
.
name
+
""
.
join
(
"
[%s]
"
%
self
.
rec
(
idx
,
PREC_NONE
)
return
(
temp_var
.
name
+
""
.
join
(
"
[%s]
"
%
self
.
rec
(
idx
,
PREC_NONE
,
'
i
'
)
for
idx
in
index
))
for
idx
in
index
))
else
:
else
:
raise
RuntimeError
(
"
nothing known about variable
'
%s
'"
%
expr
.
aggregate
.
name
)
raise
RuntimeError
(
"
nothing known about variable
'
%s
'"
%
expr
.
aggregate
.
name
)
def
map_floor_div
(
self
,
expr
,
prec
):
def
map_floor_div
(
self
,
expr
,
enclosing_prec
,
type_context
):
from
loopy.isl_helpers
import
is_nonnegative
from
loopy.isl_helpers
import
is_nonnegative
num_nonneg
=
is_nonnegative
(
expr
.
numerator
,
self
.
kernel
.
domain
)
num_nonneg
=
is_nonnegative
(
expr
.
numerator
,
self
.
kernel
.
domain
)
den_nonneg
=
is_nonnegative
(
expr
.
denominator
,
self
.
kernel
.
domain
)
den_nonneg
=
is_nonnegative
(
expr
.
denominator
,
self
.
kernel
.
domain
)
if
den_nonneg
:
if
den_nonneg
:
if
num_nonneg
:
if
num_nonneg
:
return
CCodeMapper
.
map_floor_div
(
self
,
expr
,
prec
)
return
self
.
parenthesize_if_needed
(
"
%s // %s
"
%
(
self
.
rec
(
expr
.
numerator
,
PREC_PRODUCT
,
type_context
),
# analogous to ^{-1}
self
.
rec
(
expr
.
denominator
,
PREC_POWER
,
type_context
)),
enclosing_prec
,
PREC_PRODUCT
)
else
:
else
:
return
(
"
int_floor_div_pos_b(%s, %s)
"
return
(
"
int_floor_div_pos_b(%s, %s)
"
%
(
self
.
rec
(
expr
.
numerator
,
PREC_NONE
),
%
(
self
.
rec
(
expr
.
numerator
,
PREC_NONE
,
'
i
'
),
expr
.
denominator
))
self
.
rec
(
expr
.
denominator
,
PREC_NONE
,
'
i
'
)
))
else
:
else
:
return
(
"
int_floor_div(%s, %s)
"
return
(
"
int_floor_div(%s, %s)
"
%
(
self
.
rec
(
expr
.
numerator
,
PREC_NONE
),
%
(
self
.
rec
(
expr
.
numerator
,
PREC_NONE
,
'
i
'
),
self
.
rec
(
expr
.
denominator
,
PREC_NONE
)))
self
.
rec
(
expr
.
denominator
,
PREC_NONE
,
'
i
'
)))
def
map_min
(
self
,
expr
,
prec
):
def
map_min
(
self
,
expr
,
prec
,
type_context
):
what
=
type
(
expr
).
__name__
.
lower
()
what
=
type
(
expr
).
__name__
.
lower
()
children
=
expr
.
children
[:]
children
=
expr
.
children
[:]
result
=
self
.
rec
(
children
.
pop
(),
PREC_NONE
)
result
=
self
.
rec
(
children
.
pop
(),
PREC_NONE
,
type_context
)
while
children
:
while
children
:
result
=
"
%s(%s, %s)
"
%
(
what
,
result
=
"
%s(%s, %s)
"
%
(
what
,
self
.
rec
(
children
.
pop
(),
PREC_NONE
),
self
.
rec
(
children
.
pop
(),
PREC_NONE
,
type_context
),
result
)
result
)
return
result
return
result
map_max
=
map_min
map_max
=
map_min
def
map_constant
(
self
,
expr
,
enclosing_prec
):
def
map_constant
(
self
,
expr
,
enclosing_prec
,
type_context
):
if
isinstance
(
expr
,
complex
):
if
isinstance
(
expr
,
complex
):
# FIXME: type-variable
cast_type
=
"
cdouble_t
"
return
"
(cdouble_t) (%s, %s)
"
%
(
repr
(
expr
.
real
),
repr
(
expr
.
imag
))
if
type_context
==
"
f
"
:
cast_type
=
"
cfloat_t
"
return
"
(%s) (%s, %s)
"
%
(
cast_type
,
repr
(
expr
.
real
),
repr
(
expr
.
imag
))
else
:
else
:
# FIXME: type-variable
if
type_context
==
"
f
"
:
return
repr
(
float
(
expr
))
return
repr
(
float
(
expr
))
+
"
f
"
elif
type_context
==
"
d
"
:
return
repr
(
float
(
expr
))
elif
type_context
==
"
i
"
:
return
str
(
int
(
expr
))
else
:
raise
RuntimeError
(
"
don
'
t know how to generated code
"
"
for constant
'
%s
'"
%
expr
)
def
map_call
(
self
,
expr
,
enclosing_prec
):
def
map_call
(
self
,
expr
,
enclosing_prec
,
type_context
):
from
pymbolic.primitives
import
Variable
from
pymbolic.primitives
import
Variable
from
pymbolic.mapper.stringifier
import
PREC_NONE
from
pymbolic.mapper.stringifier
import
PREC_NONE
...
@@ -311,7 +372,7 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -311,7 +372,7 @@ class LoopyCCodeMapper(CCodeMapper):
par_dtypes
=
tuple
(
self
.
infer_type
(
par
)
for
par
in
expr
.
parameters
)
par_dtypes
=
tuple
(
self
.
infer_type
(
par
)
for
par
in
expr
.
parameters
)
parameters
=
expr
.
parameters
str_
parameters
=
None
mangle_result
=
self
.
kernel
.
mangle_function
(
identifier
,
par_dtypes
)
mangle_result
=
self
.
kernel
.
mangle_function
(
identifier
,
par_dtypes
)
if
mangle_result
is
not
None
:
if
mangle_result
is
not
None
:
...
@@ -320,23 +381,28 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -320,23 +381,28 @@ class LoopyCCodeMapper(CCodeMapper):
elif
len
(
mangle_result
)
==
3
:
elif
len
(
mangle_result
)
==
3
:
result_dtype
,
c_name
,
arg_tgt_dtypes
=
mangle_result
result_dtype
,
c_name
,
arg_tgt_dtypes
=
mangle_result
parameters
=
[
str_parameters
=
[
perform_cast
(
self
,
par
,
par_dtype
,
tgt_dtype
)
self
.
rec
(
perform_cast
(
self
,
par
,
par_dtype
,
tgt_dtype
),
PREC_NONE
,
dtype_to_type_context
(
tgt_dtype
))
for
par
,
par_dtype
,
tgt_dtype
in
zip
(
for
par
,
par_dtype
,
tgt_dtype
in
zip
(
parameters
,
par_dtypes
,
arg_tgt_dtypes
)]
expr
.
parameters
,
par_dtypes
,
arg_tgt_dtypes
)]
else
:
else
:
raise
RuntimeError
(
"
result of function mangler
"
raise
RuntimeError
(
"
result of function mangler
"
"
for function
'
%s
'
not understood
"
"
for function
'
%s
'
not understood
"
%
identifier
)
%
identifier
)
self
.
seen_functions
.
add
((
identifier
,
c_name
,
par_dtypes
))
self
.
seen_functions
.
add
((
identifier
,
c_name
,
par_dtypes
))
if
str_parameters
is
None
:
str_parameters
=
[
self
.
rec
(
par
,
PREC_NONE
,
type_context
)
for
par
in
expr
.
parameters
]
if
c_name
is
None
:
if
c_name
is
None
:
raise
RuntimeError
(
"
unable to find C name for function identifier
'
%s
'"
raise
RuntimeError
(
"
unable to find C name for function identifier
'
%s
'"
%
identifier
)
%
identifier
)
return
self
.
format
(
"
%s(%s)
"
,
return
"
%s(%s)
"
%
(
c_name
,
"
,
"
.
join
(
str_parameters
))
c_name
,
self
.
join_rec
(
"
,
"
,
parameters
,
PREC_NONE
))
# {{{ deal with complex-valued variables
# {{{ deal with complex-valued variables
...
@@ -348,15 +414,22 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -348,15 +414,22 @@ class LoopyCCodeMapper(CCodeMapper):
else
:
else
:
raise
RuntimeError
raise
RuntimeError
def
map_sum
(
self
,
expr
,
enclosing_prec
):
def
map_sum
(
self
,
expr
,
enclosing_prec
,
type_context
):
from
pymbolic.mapper.stringifier
import
PREC_SUM
def
base_impl
(
expr
,
enclosing_prec
,
type_context
):
return
self
.
parenthesize_if_needed
(
self
.
join_rec
(
"
+
"
,
expr
.
children
,
PREC_SUM
,
type_context
),
enclosing_prec
,
PREC_SUM
)
if
not
self
.
allow_complex
:
if
not
self
.
allow_complex
:
return
CCodeMapper
.
map_sum
(
self
,
expr
,
enclosing_prec
)
return
base_impl
(
expr
,
enclosing_prec
,
type_context
)
tgt_dtype
=
self
.
infer_type
(
expr
)
tgt_dtype
=
self
.
infer_type
(
expr
)
is_complex
=
tgt_dtype
.
kind
==
'
c
'
is_complex
=
tgt_dtype
.
kind
==
'
c
'
if
not
is_complex
:
if
not
is_complex
:
return
CCodeMapper
.
map_sum
(
self
,
expr
,
enclosing_prec
)
return
base_impl
(
expr
,
enclosing_prec
,
type_context
)
else
:
else
:
tgt_name
=
self
.
complex_type_name
(
tgt_dtype
)
tgt_name
=
self
.
complex_type_name
(
tgt_dtype
)
...
@@ -365,9 +438,8 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -365,9 +438,8 @@ class LoopyCCodeMapper(CCodeMapper):
complexes
=
[
child
for
child
in
expr
.
children
complexes
=
[
child
for
child
in
expr
.
children
if
'
c
'
==
self
.
infer_type
(
child
).
kind
]
if
'
c
'
==
self
.
infer_type
(
child
).
kind
]
from
pymbolic.mapper.stringifier
import
PREC_SUM
real_sum
=
self
.
join_rec
(
"
+
"
,
reals
,
PREC_SUM
,
type_context
)
real_sum
=
self
.
join_rec
(
"
+
"
,
reals
,
PREC_SUM
)
complex_sum
=
self
.
join_rec
(
"
+
"
,
complexes
,
PREC_SUM
,
type_context
)
complex_sum
=
self
.
join_rec
(
"
+
"
,
complexes
,
PREC_SUM
)
if
real_sum
:
if
real_sum
:
result
=
"
%s_fromreal(%s) + %s
"
%
(
tgt_name
,
real_sum
,
complex_sum
)
result
=
"
%s_fromreal(%s) + %s
"
%
(
tgt_name
,
real_sum
,
complex_sum
)
...
@@ -376,15 +448,22 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -376,15 +448,22 @@ class LoopyCCodeMapper(CCodeMapper):
return
self
.
parenthesize_if_needed
(
result
,
enclosing_prec
,
PREC_SUM
)
return
self
.
parenthesize_if_needed
(
result
,
enclosing_prec
,
PREC_SUM
)
def
map_product
(
self
,
expr
,
enclosing_prec
):
def
map_product
(
self
,
expr
,
enclosing_prec
,
type_context
):
def
base_impl
(
expr
,
enclosing_prec
,
type_context
):
# Spaces prevent '**z' (times dereference z), which
# is hard to read.
return
self
.
parenthesize_if_needed
(
self
.
join_rec
(
"
*
"
,
expr
.
children
,
PREC_PRODUCT
,
type_context
),
enclosing_prec
,
PREC_PRODUCT
)
if
not
self
.
allow_complex
:
if
not
self
.
allow_complex
:
return
CCodeMapper
.
map_product
(
self
,
expr
,
enclosing_prec
)
return
base_impl
(
expr
,
enclosing_prec
,
type_context
)
tgt_dtype
=
self
.
infer_type
(
expr
)
tgt_dtype
=
self
.
infer_type
(
expr
)
is_complex
=
'
c
'
==
tgt_dtype
.
kind
is_complex
=
'
c
'
==
tgt_dtype
.
kind
if
not
is_complex
:
if
not
is_complex
:
return
CCodeMapper
.
map_product
(
self
,
expr
,
enclosing_prec
)
return
base_impl
(
expr
,
enclosing_prec
,
type_context
)
else
:
else
:
tgt_name
=
self
.
complex_type_name
(
tgt_dtype
)
tgt_name
=
self
.
complex_type_name
(
tgt_dtype
)
...
@@ -393,19 +472,18 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -393,19 +472,18 @@ class LoopyCCodeMapper(CCodeMapper):
complexes
=
[
child
for
child
in
expr
.
children
complexes
=
[
child
for
child
in
expr
.
children
if
'
c
'
==
self
.
infer_type
(
child
).
kind
]
if
'
c
'
==
self
.
infer_type
(
child
).
kind
]
from
pymbolic.mapper.stringifier
import
PREC_PRODUCT
real_prd
=
self
.
join_rec
(
"
*
"
,
reals
,
PREC_PRODUCT
,
type_context
)
real_prd
=
self
.
join_rec
(
"
*
"
,
reals
,
PREC_PRODUCT
)
if
len
(
complexes
)
==
1
:
if
len
(
complexes
)
==
1
:
myprec
=
PREC_PRODUCT
myprec
=
PREC_PRODUCT
else
:
else
:
myprec
=
PREC_NONE
myprec
=
PREC_NONE
complex_prd
=
self
.
rec
(
complexes
[
0
],
myprec
)
complex_prd
=
self
.
rec
(
complexes
[
0
],
myprec
,
type_context
)
for
child
in
complexes
[
1
:]:
for
child
in
complexes
[
1
:]:
complex_prd
=
"
%s_mul(%s, %s)
"
%
(
complex_prd
=
"
%s_mul(%s, %s)
"
%
(
tgt_name
,
complex_prd
,
tgt_name
,
complex_prd
,
self
.
rec
(
child
,
PREC_NONE
))
self
.
rec
(
child
,
PREC_NONE
,
type_context
))
if
real_prd
:
if
real_prd
:
# elementwise semantics are correct
# elementwise semantics are correct
...
@@ -415,9 +493,19 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -415,9 +493,19 @@ class LoopyCCodeMapper(CCodeMapper):
return
self
.
parenthesize_if_needed
(
result
,
enclosing_prec
,
PREC_PRODUCT
)
return
self
.
parenthesize_if_needed
(
result
,
enclosing_prec
,
PREC_PRODUCT
)
def
map_quotient
(
self
,
expr
,
enclosing_prec
):
def
map_quotient
(
self
,
expr
,
enclosing_prec
,
type_context
):
def
base_impl
(
expr
,
enclosing_prec
,
type_context
):
return
self
.
parenthesize_if_needed
(
"
%s / %s
"
%
(
# space is necessary--otherwise '/*' becomes
# start-of-comment in C.
self
.
rec
(
expr
.
numerator
,
PREC_PRODUCT
,
type_context
),
# analogous to ^{-1}
self
.
rec
(
expr
.
denominator
,
PREC_POWER
,
type_context
)),
enclosing_prec
,
PREC_PRODUCT
)
if
not
self
.
allow_complex
:
if
not
self
.
allow_complex
:
return
CCodeMapper
.
map_quotient
(
self
,
expr
,
enclosing_prec
)
return
base_impl
(
expr
,
enclosing_prec
,
type_context
)
n_complex
=
'
c
'
==
self
.
infer_type
(
expr
.
numerator
).
kind
n_complex
=
'
c
'
==
self
.
infer_type
(
expr
.
numerator
).
kind
d_complex
=
'
c
'
==
self
.
infer_type
(
expr
.
denominator
).
kind
d_complex
=
'
c
'
==
self
.
infer_type
(
expr
.
denominator
).
kind
...
@@ -425,36 +513,48 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -425,36 +513,48 @@ class LoopyCCodeMapper(CCodeMapper):
tgt_dtype
=
self
.
infer_type
(
expr
)
tgt_dtype
=
self
.
infer_type
(
expr
)
if
not
(
n_complex
or
d_complex
):
if
not
(
n_complex
or
d_complex
):
return
CCodeMapper
.
map_quotient
(
self
,
expr
,
enclosing_prec
)
return
base_impl
(
expr
,
enclosing_prec
,
type_context
)
elif
n_complex
and
not
d_complex
:
elif
n_complex
and
not
d_complex
:
# elementwise semnatics are correct
# elementwise semnatics are correct
return
CCodeMapper
.
map_quotient
(
self
,
expr
,
enclosing_prec
)
return
base_impl
(
expr
,
enclosing_prec
,
type_context
)
elif
not
n_complex
and
d_complex
:
elif
not
n_complex
and
d_complex
:
return
"
%s_rdivide(%s, %s)
"
%
(
return
"
%s_rdivide(%s, %s)
"
%
(
self
.
complex_type_name
(
tgt_dtype
),
self
.
complex_type_name
(
tgt_dtype
),
self
.
rec
(
expr
.
numerator
,
PREC_NONE
),
self
.
rec
(
expr
.
numerator
,
PREC_NONE
,
type_context
),
self
.
rec
(
expr
.
denominator
,
PREC_NONE
))
self
.
rec
(
expr
.
denominator
,
PREC_NONE
,
type_context
))
else
:
else
:
return
"
%s_divide(%s, %s)
"
%
(
return
"
%s_divide(%s, %s)
"
%
(
self
.
complex_type_name
(
tgt_dtype
),
self
.
complex_type_name
(
tgt_dtype
),
self
.
rec
(
expr
.
numerator
,
PREC_NONE
),
self
.
rec
(
expr
.
numerator
,
PREC_NONE
,
type_context
),
self
.
rec
(
expr
.
denominator
,
PREC_NONE
))
self
.
rec
(
expr
.
denominator
,
PREC_NONE
,
type_context
))
def
map_remainder
(
self
,
expr
,
enclosing_prec
):
if
not
self
.
allow_complex
:
return
CCodeMapper
.
map_remainder
(
self
,
expr
,
enclosing_prec
)
def
map_remainder
(
self
,
expr
,
enclosing_prec
,
type_context
):
tgt_dtype
=
self
.
infer_type
(
expr
)
tgt_dtype
=
self
.
infer_type
(
expr
)
if
'
c
'
==
tgt_dtype
.
kind
:
if
'
c
'
==
tgt_dtype
.
kind
:
raise
RuntimeError
(
"
complex remainder not defined
"
)
raise
RuntimeError
(
"
complex remainder not defined
"
)
return
CCodeMapper
.
map_remainder
(
self
,
expr
,
enclosing_prec
)
return
"
(%s %% %s)
"
%
(
self
.
rec
(
expr
.
numerator
,
PREC_PRODUCT
,
type_context
),
self
.
rec
(
expr
.
denominator
,
PREC_POWER
,
type_context
))
# analogous to ^{-1}
def
map_power
(
self
,
expr
,
enclosing_prec
,
type_context
):
def
base_impl
(
expr
,
enclosing_prec
,
type_context
):
from
pymbolic.mapper.stringifier
import
PREC_NONE
from
pymbolic.primitives
import
is_constant
,
is_zero
if
is_constant
(
expr
.
exponent
):
if
is_zero
(
expr
.
exponent
):
return
"
1
"
elif
is_zero
(
expr
.
exponent
-
1
):
return
self
.
rec
(
expr
.
base
,
enclosing_prec
,
type_context
)
elif
is_zero
(
expr
.
exponent
-
2
):
return
self
.
rec
(
expr
.
base
*
expr
.
base
,
enclosing_prec
,
type_context
)
return
"
pow(%s, %s)
"
%
(
self
.
rec
(
expr
.
base
,
PREC_NONE
,
type_context
),
self
.
rec
(
expr
.
exponent
,
PREC_NONE
,
type_context
))
def
map_power
(
self
,
expr
,
enclosing_prec
):
if
not
self
.
allow_complex
:
if
not
self
.
allow_complex
:
return
CCodeMapper
.
map_power
(
self
,
expr
,
enclosing_prec
)
return
base_impl
(
expr
,
enclosing_prec
,
type_context
)
from
pymbolic.mapper.stringifier
import
PREC_NONE
tgt_dtype
=
self
.
infer_type
(
expr
)
tgt_dtype
=
self
.
infer_type
(
expr
)
if
'
c
'
==
tgt_dtype
.
kind
:
if
'
c
'
==
tgt_dtype
.
kind
:
...
@@ -462,7 +562,7 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -462,7 +562,7 @@ class LoopyCCodeMapper(CCodeMapper):
value
=
expr
.
base
value
=
expr
.
base
for
i
in
range
(
expr
.
exponent
-
1
):
for
i
in
range
(
expr
.
exponent
-
1
):
value
=
value
*
expr
.
base
value
=
value
*
expr
.
base
return
self
.
rec
(
value
,
enclosing_prec
)
return
self
.
rec
(
value
,
enclosing_prec
,
type_context
)
else
:
else
:
b_complex
=
'
c
'
==
self
.
infer_type
(
expr
.
base
).
kind
b_complex
=
'
c
'
==
self
.
infer_type
(
expr
.
base
).
kind
e_complex
=
'
c
'
==
self
.
infer_type
(
expr
.
exponent
).
kind
e_complex
=
'
c
'
==
self
.
infer_type
(
expr
.
exponent
).
kind
...
@@ -470,18 +570,22 @@ class LoopyCCodeMapper(CCodeMapper):
...
@@ -470,18 +570,22 @@ class LoopyCCodeMapper(CCodeMapper):
if
b_complex
and
not
e_complex
:
if
b_complex
and
not
e_complex
:
return
"
%s_powr(%s, %s)
"
%
(
return
"
%s_powr(%s, %s)
"
%
(
self
.
complex_type_name
(
tgt_dtype
),
self
.
complex_type_name
(
tgt_dtype
),
self
.
rec
(
expr
.
base
,
PREC_NONE
),
self
.
rec
(
expr
.
base
,
PREC_NONE
,
type_context
),
self
.
rec
(
expr
.
exponent
,
PREC_NONE
))
self
.
rec
(
expr
.
exponent
,
PREC_NONE
,
type_context
))
else
:
else
:
return
"
%s_pow(%s, %s)
"
%
(
return
"
%s_pow(%s, %s)
"
%
(
self
.
complex_type_name
(
tgt_dtype
),
self
.
complex_type_name
(
tgt_dtype
),
self
.
rec
(
expr
.
base
,
PREC_NONE
),
self
.
rec
(
expr
.
base
,
PREC_NONE
,
type_context
),
self
.
rec
(
expr
.
exponent
,
PREC_NONE
))
self
.
rec
(
expr
.
exponent
,
PREC_NONE
,
type_context
))
return
CCodeMapper
.
map_power
(
self
,
expr
,
enclosing_prec
)
return
base_impl
(
self
,
expr
,
enclosing_prec
,
type_context
)
# }}}
# }}}
def
__call__
(
self
,
expr
,
type_context
,
prec
=
PREC_NONE
):
from
pymbolic.mapper
import
RecursiveMapper
return
RecursiveMapper
.
__call__
(
self
,
expr
,
prec
,
type_context
)
# }}}
# }}}
# vim: fdm=marker
# vim: fdm=marker
This diff is collapsed.
Click to expand it.
loopy/codegen/instruction.py
+
10
−
3
View file @
91be1b94
...
@@ -12,11 +12,18 @@ def generate_instruction_code(kernel, insn, codegen_state):
...
@@ -12,11 +12,18 @@ def generate_instruction_code(kernel, insn, codegen_state):
expr
=
insn
.
expression
expr
=
insn
.
expression
from
loopy.codegen.expression
import
perform_cast
from
loopy.codegen.expression
import
perform_cast
expr
=
perform_cast
(
ccm
,
expr
,
expr_dtype
=
ccm
.
infer_type
(
expr
),
target_dtype
=
kernel
.
get_var_descriptor
(
insn
.
get_assignee_var_name
()).
dtype
target_dtype
=
kernel
.
get_var_descriptor
(
insn
.
get_assignee_var_name
()).
dtype
)
expr_dtype
=
ccm
.
infer_type
(
expr
)
expr
=
perform_cast
(
ccm
,
expr
,
expr_dtype
=
expr_dtype
,
target_dtype
=
target_dtype
)
from
cgen
import
Assign
from
cgen
import
Assign
insn_code
=
Assign
(
ccm
(
insn
.
assignee
),
ccm
(
expr
))
from
loopy.codegen.expression
import
dtype_to_type_context
insn_code
=
Assign
(
ccm
(
insn
.
assignee
,
prec
=
None
,
type_context
=
None
),
ccm
(
expr
,
prec
=
None
,
type_context
=
dtype_to_type_context
(
target_dtype
)))
from
loopy.codegen.bounds
import
wrap_in_bounds_checks
from
loopy.codegen.bounds
import
wrap_in_bounds_checks
insn_inames
=
kernel
.
insn_inames
(
insn
)
insn_inames
=
kernel
.
insn_inames
(
insn
)
insn_code
,
impl_domain
=
wrap_in_bounds_checks
(
insn_code
,
impl_domain
=
wrap_in_bounds_checks
(
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment