Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
L
loopy
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Kaushik Kulkarni
loopy
Commits
73b3ab13
Commit
73b3ab13
authored
9 years ago
by
Andreas Klöckner
Browse files
Options
Downloads
Patches
Plain Diff
Working version of the distributive law transform
parent
dfd763c2
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
loopy/transform/arithmetic.py
+166
-58
166 additions, 58 deletions
loopy/transform/arithmetic.py
with
166 additions
and
58 deletions
loopy/transform/arithmetic.py
+
166
−
58
View file @
73b3ab13
...
...
@@ -152,28 +152,106 @@ def fold_constants(kernel):
# {{{ collect_common_factors_on_increment
# thus far undocumented
def
collect_common_factors_on_increment
(
kernel
,
var_name
,
is_index_specific
=
False
):
def
collect_common_factors_on_increment
(
kernel
,
var_name
,
vary_by_axes
=
()
):
# FIXME: Does not understand subst rules for now
if
kernel
.
substitutions
:
from
loopy.transform.subst
import
expand_subst
kernel
=
expand_subst
(
kernel
)
if
var_name
in
kernel
.
temporary_variables
:
var_descr
=
kernel
.
temporary_variables
[
var_name
]
elif
var_name
in
kernel
.
arg_dict
:
var_descr
=
kernel
.
arg_dict
[
var_name
]
else
:
raise
NameError
(
"
array
'
%s
'
was not found
"
%
var_name
)
# {{{ check/normalize vary_by_axes
if
isinstance
(
vary_by_axes
,
str
):
vary_by_axes
=
vary_by_axes
.
split
(
"
,
"
)
from
loopy.kernel.array
import
ArrayBase
if
isinstance
(
var_descr
,
ArrayBase
):
if
var_descr
.
dim_names
is
not
None
:
name_to_index
=
dict
(
(
name
,
idx
)
for
idx
,
name
in
enumerate
(
var_descr
.
dim_names
))
else
:
name_to_index
=
{}
def map_ax_name_to_index(ax):
    """Translate an axis given by name into its entry in *name_to_index*.

    Values that are not :class:`str` instances are handed back unchanged.

    :raises LoopyError: if *ax* is a string with no entry in
        *name_to_index*.
    """
    if not isinstance(ax, str):
        return ax

    try:
        return name_to_index[ax]
    except KeyError:
        raise LoopyError("axis name '%s' not understood" % ax)
vary_by_axes
=
[
map_ax_name_to_index
(
ax
)
for
ax
in
vary_by_axes
]
if
(
vary_by_axes
and
(
min
(
vary_by_axes
)
<
0
or
max
(
vary_by_axes
)
>
var_descr
.
num_user_axes
())):
raise
LoopyError
(
"
vary_by_axes refers to out-of-bounds axis index
"
)
# }}}
from
pymbolic.mapper.substitutor
import
make_subst_func
from
pymbolic.primitives
import
(
Sum
,
Product
,
is_zero
,
flattened_sum
,
flattened_product
,
Subscript
,
Variable
)
from
loopy.symbolic
import
get_dependencies
,
SubstitutionMapper
from
loopy.symbolic
import
(
get_dependencies
,
SubstitutionMapper
,
UnidirectionalUnifier
)
# {{{
find
common factor
s
# {{{ common factor
key list maintenance
#
maps lhs indices (or N
on
e
f
or is_index_specific
)
common_factors
=
{}
#
list of (index_key, comm
on f
actors found
)
common_factors
=
[]
from
loopy.kernel.data
import
ExpressionInstruction
def find_unifiable_cf_index(index_key):
    """Look through *common_factors* for an entry whose key unifies with
    *index_key*.

    For every recorded key, a :class:`UnidirectionalUnifier` is built whose
    mapping candidates are the dependencies of that key, and it is applied
    to ``(recorded key, index_key)``.

    :returns: a tuple ``(position, unification)`` for the first entry that
        yields a non-empty unification result, where *unification* is the
        sole element of that result; ``(None, None)`` if no entry matches.
    """
    for position, (known_key, _) in enumerate(common_factors):
        unifier = UnidirectionalUnifier(
                lhs_mapping_candidates=get_dependencies(known_key))
        matches = unifier(known_key, index_key)

        if not matches:
            continue

        # Exactly one unification is expected for a matching key.
        assert len(matches) == 1
        return position, matches[0]

    return None, None
def extract_index_key(access_expr):
    """Build the index key for an access to the variable being processed.

    A plain :class:`Variable` yields the empty tuple. For a
    :class:`Subscript`, the key consists of the entries of its
    ``index_tuple`` picked out by *vary_by_axes*, in that order.

    :raises ValueError: if *access_expr* is neither a :class:`Variable`
        nor a :class:`Subscript`.
    """
    if isinstance(access_expr, Variable):
        return ()

    if not isinstance(access_expr, Subscript):
        raise ValueError("unexpected type of access_expr")

    subscripts = access_expr.index_tuple
    return tuple(subscripts[axis] for axis in vary_by_axes)
def is_assignee(insn):
    """Return whether *insn* lists *var_name* among its assignees.

    The names are taken from the ``(name, subscript)`` pairs returned by
    ``insn.assignees_and_indices()``; the subscripts are ignored.
    """
    for assignee_name, _ in insn.assignees_and_indices():
        if assignee_name == var_name:
            return True

    return False
def iterate_as(cls, expr):
    """Iterate over *expr* as if it were an instance of *cls*.

    If *expr* is an instance of *cls*, its ``children`` are yielded one by
    one. Otherwise *expr* itself is yielded as the only item.
    """
    members = expr.children if isinstance(expr, cls) else (expr,)

    for member in members:
        yield member
# }}}
# {{{ find common factors
from
loopy.kernel.data
import
ExpressionInstruction
for
insn
in
kernel
.
instructions
:
if
not
is_assignee
(
insn
):
continue
...
...
@@ -182,46 +260,70 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals
raise
LoopyError
(
"'
%s
'
modified by non-expression instruction
"
%
var_name
)
(
_
,
index_key
),
=
insn
.
assignees_and_indices
()
if
not
is_index_specific
:
index_key
=
None
lhs
=
insn
.
assignee
rhs
=
insn
.
expression
if
is_zero
(
rhs
):
continue
if
isinstance
(
rhs
,
Sum
):
sum_terms
=
rhs
.
children
index_key
=
extract_index_key
(
lhs
)
cf_index
,
unif_result
=
find_unifiable_cf_index
(
index_key
)
if
cf_index
is
None
:
# {{{ doesn't exist yet
assert
unif_result
is
None
my_common_factors
=
None
for
term
in
iterate_as
(
Sum
,
rhs
):
if
term
==
lhs
:
continue
for
part
in
iterate_as
(
Product
,
term
):
if
var_name
in
get_dependencies
(
part
):
raise
LoopyError
(
"
unexpected dependency on
'
%s
'
"
"
in RHS of instruction
'
%s
'"
%
(
var_name
,
insn
.
id
))
product_parts
=
set
(
iterate_as
(
Product
,
term
))
if
my_common_factors
is
None
:
my_common_factors
=
product_parts
else
:
my_common_factors
=
my_common_factors
&
product_parts
if
my_common_factors
is
not
None
:
common_factors
.
append
((
index_key
,
my_common_factors
))
# }}}
else
:
sum_terms
=
[
rhs
]
# {{{ match, filter existing common factors
my_common_factors
=
common_factors
.
get
(
index_key
)
_
,
my_common_factors
=
common_factors
[
cf_index
]
for
term
in
sum_terms
:
if
term
==
lhs
:
continue
unif_subst_map
=
SubstitutionMapper
(
make_subst_func
(
unif_result
.
lmap
))
if
isinstance
(
term
,
Product
):
product_parts
=
set
(
term
.
children
)
else
:
product_parts
=
set
([
term
])
for
term
in
iterate_as
(
Sum
,
rhs
):
if
term
==
lhs
:
continue
for
part
in
product_parts
:
if
var_name
in
get_dependencies
(
part
):
raise
LoopyError
(
"
unexpected dependency on
'
%s
'
"
"
in RHS of instruction
'
%s
'"
%
(
var_name
,
insn
.
id
))
for
part
in
iterate_as
(
Product
,
term
)
:
if
var_name
in
get_dependencies
(
part
):
raise
LoopyError
(
"
unexpected dependency on
'
%s
'
"
"
in RHS of instruction
'
%s
'"
%
(
var_name
,
insn
.
id
))
if
my_common_factors
is
None
:
my_common_factors
=
product_parts
else
:
my_common_factors
=
my_common_factors
&
product_parts
product_parts
=
set
(
iterate_as
(
Product
,
term
))
if
my_common_factors
is
not
None
:
common_factors
[
index_key
]
=
my_common_factors
my_common_factors
=
set
(
cf
for
cf
in
my_common_factors
if
unif_subst_map
(
cf
)
in
product_parts
)
common_factors
[
cf_index
]
=
(
index_key
,
my_common_factors
)
# }}}
# }}}
...
...
@@ -236,9 +338,6 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals
(
_
,
index_key
),
=
insn
.
assignees_and_indices
()
if
not
is_index_specific
:
index_key
=
None
lhs
=
insn
.
assignee
rhs
=
insn
.
expression
...
...
@@ -246,29 +345,34 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals
new_insns
.
append
(
insn
)
continue
if
isinstance
(
rhs
,
Sum
):
sum_terms
=
rhs
.
children
else
:
sum_terms
=
[
rhs
]
index_key
=
extract_index_key
(
lhs
)
cf_index
,
unif_result
=
find_unifiable_cf_index
(
index_key
)
if
cf_index
is
None
:
new_insns
.
append
(
insn
)
continue
_
,
my_common_factors
=
common_factors
[
cf_index
]
unif_subst_map
=
SubstitutionMapper
(
make_subst_func
(
unif_result
.
lmap
))
mapped_my_common_factors
=
set
(
unif_subst_map
(
cf
)
for
cf
in
my_common_factors
)
my_common_factors
=
common_factors
.
get
(
index_key
)
new_sum_terms
=
[]
for
term
in
sum_terms
:
for
term
in
iterate_as
(
Sum
,
rhs
)
:
if
term
==
lhs
:
new_sum_terms
.
append
(
term
)
continue
if
isinstance
(
term
,
Product
):
product_parts
=
term
.
children
else
:
product_parts
=
[
term
]
new_sum_terms
.
append
(
flattened_product
([
part
for
part
in
product_parts
if
part
not
in
my_common_factors
for
part
in
iterate_as
(
Product
,
term
)
if
part
not
in
mapped_
my_common_factors
]))
new_insns
.
append
(
...
...
@@ -280,23 +384,27 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals
def
find_substitution
(
expr
):
if
isinstance
(
expr
,
Subscript
):
if
is_index_specific
:
index_key
=
expr
.
index
else
:
index_key
=
None
v
=
expr
.
aggregate
.
name
elif
isinstance
(
expr
,
Variable
):
v
=
expr
.
name
else
:
return
None
return
expr
if
v
!=
var_name
:
return
None
return
expr
index_key
=
extract_index_key
(
expr
)
cf_index
,
unif_result
=
find_unifiable_cf_index
(
index_key
)
unif_subst_map
=
SubstitutionMapper
(
make_subst_func
(
unif_result
.
lmap
))
_
,
my_common_factors
=
common_factors
[
cf_index
]
my_common_factors
=
common_factors
.
get
(
index_key
)
if
my_common_factors
is
not
None
:
return
flattened_product
(
list
(
my_common_factors
)
+
[
expr
])
return
flattened_product
(
[
unif_subst_map
(
cf
)
for
cf
in
my_common_factors
]
+
[
expr
])
else
:
return
expr
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment