Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
pymbolic
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Lawrence Mitchell
pymbolic
Commits
ca628341
Commit
ca628341
authored
13 years ago
by
Andreas Klöckner
Browse files
Options
Downloads
Patches
Plain Diff
Various CSE fixes.
parent
a911877b
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
pymbolic/cse.py
+36
-18
36 additions, 18 deletions
pymbolic/cse.py
pymbolic/primitives.py
+20
-0
20 additions, 0 deletions
pymbolic/primitives.py
pymbolic/sympy_conv.py
+0
-2
0 additions, 2 deletions
pymbolic/sympy_conv.py
with
56 additions
and
20 deletions
pymbolic/cse.py
+
36
−
18
View file @
ca628341
from
__future__
import
division
from
__future__
import
division
import
pymbolic.primitives
as
prim
import
pymbolic.primitives
as
prim
from
pymbolic.mapper
import
IdentityMapper
,
WalkMapper
from
pymbolic.mapper
import
IdentityMapper
,
WalkMapper
from
pytools
import
memoize_method
COMMUTATIVE_CLASSES
=
(
prim
.
Sum
,
prim
.
Product
)
COMMUTATIVE_CLASSES
=
(
prim
.
Sum
,
prim
.
Product
)
def
get_normalized_cse_key
(
node
):
class
CSERemover
(
IdentityMapper
):
if
isinstance
(
node
,
COMMUTATIVE_CLASSES
):
def
map_common_subexpression
(
self
,
expr
):
return
type
(
node
),
frozenset
(
node
.
children
)
return
self
.
rec
(
expr
.
child
)
else
:
return
node
class
NormalizedKeyGetter
(
object
):
def
__init__
(
self
):
self
.
cse_remover
=
CSERemover
()
@memoize_method
def
remove_cses
(
self
,
expr
):
return
self
.
cse_remover
(
expr
)
def
__call__
(
self
,
expr
):
expr
=
self
.
remove_cses
(
expr
)
if
isinstance
(
expr
,
COMMUTATIVE_CLASSES
):
return
type
(
expr
),
frozenset
(
expr
.
children
)
else
:
return
expr
class
CSEMapper
(
IdentityMapper
):
class
CSEMapper
(
IdentityMapper
):
def
__init__
(
self
,
to_eliminate
):
def
__init__
(
self
,
to_eliminate
,
get_key
):
self
.
to_eliminate
=
to_eliminate
self
.
to_eliminate
=
to_eliminate
self
.
get_key
=
get_key
self
.
canonical_subexprs
=
{}
self
.
canonical_subexprs
=
{}
def
get_cse
(
self
,
expr
,
key
=
None
):
def
get_cse
(
self
,
expr
,
key
=
None
):
if
key
is
None
:
if
key
is
None
:
key
=
get_normalized_cse
_key
(
expr
)
key
=
self
.
get
_key
(
expr
)
try
:
try
:
return
self
.
canonical_subexprs
[
key
]
return
self
.
canonical_subexprs
[
key
]
except
KeyError
:
except
KeyError
:
new_expr
=
prim
.
CommonSubexpression
(
new_expr
=
prim
.
wrap_in_cse
(
getattr
(
IdentityMapper
,
expr
.
mapper_method
)(
self
,
expr
))
getattr
(
IdentityMapper
,
expr
.
mapper_method
)(
self
,
expr
))
self
.
canonical_subexprs
[
key
]
=
new_expr
self
.
canonical_subexprs
[
key
]
=
new_expr
return
new_expr
return
new_expr
def
map_sum
(
self
,
expr
):
def
map_sum
(
self
,
expr
):
key
=
get_normalized_cse
_key
(
expr
)
key
=
self
.
get
_key
(
expr
)
if
key
in
self
.
to_eliminate
:
if
key
in
self
.
to_eliminate
:
result
=
self
.
get_cse
(
expr
,
key
)
result
=
self
.
get_cse
(
expr
,
key
)
return
result
return
result
...
@@ -49,11 +67,9 @@ class CSEMapper(IdentityMapper):
...
@@ -49,11 +67,9 @@ class CSEMapper(IdentityMapper):
map_floor_div
=
map_sum
map_floor_div
=
map_sum
map_call
=
map_sum
map_call
=
map_sum
def
map_quotient
(
self
,
expr
):
def
map_common_subexpression
(
self
,
expr
):
if
expr
in
self
.
to_eliminate
:
# don't duplicate CSEs
return
self
.
get_cse
(
expr
)
return
prim
.
wrap_in_cse
(
self
.
rec
(
expr
.
child
),
expr
.
prefix
)
else
:
return
IdentityMapper
.
map_quotient
(
self
,
expr
)
def
map_substitution
(
self
,
expr
):
def
map_substitution
(
self
,
expr
):
return
type
(
expr
)(
return
type
(
expr
)(
...
@@ -65,11 +81,12 @@ class CSEMapper(IdentityMapper):
...
@@ -65,11 +81,12 @@ class CSEMapper(IdentityMapper):
class
UseCountMapper
(
WalkMapper
):
class
UseCountMapper
(
WalkMapper
):
def
__init__
(
self
):
def
__init__
(
self
,
get_key
):
self
.
subexpr_counts
=
{}
self
.
subexpr_counts
=
{}
self
.
get_key
=
get_key
def
visit
(
self
,
expr
):
def
visit
(
self
,
expr
):
key
=
get_normalized_cse
_key
(
expr
)
key
=
self
.
get
_key
(
expr
)
if
key
in
self
.
subexpr_counts
:
if
key
in
self
.
subexpr_counts
:
self
.
subexpr_counts
[
key
]
+=
1
self
.
subexpr_counts
[
key
]
+=
1
...
@@ -86,7 +103,8 @@ class UseCountMapper(WalkMapper):
...
@@ -86,7 +103,8 @@ class UseCountMapper(WalkMapper):
def
tag_common_subexpressions
(
exprs
):
def
tag_common_subexpressions
(
exprs
):
ucm
=
UseCountMapper
()
get_key
=
NormalizedKeyGetter
()
ucm
=
UseCountMapper
(
get_key
)
if
isinstance
(
exprs
,
prim
.
Expression
):
if
isinstance
(
exprs
,
prim
.
Expression
):
raise
TypeError
(
"
exprs should be an iterable of expressions
"
)
raise
TypeError
(
"
exprs should be an iterable of expressions
"
)
...
@@ -97,7 +115,7 @@ def tag_common_subexpressions(exprs):
...
@@ -97,7 +115,7 @@ def tag_common_subexpressions(exprs):
to_eliminate
=
set
([
subexpr_key
to_eliminate
=
set
([
subexpr_key
for
subexpr_key
,
count
in
ucm
.
subexpr_counts
.
iteritems
()
for
subexpr_key
,
count
in
ucm
.
subexpr_counts
.
iteritems
()
if
count
>
1
])
if
count
>
1
])
cse_mapper
=
CSEMapper
(
to_eliminate
)
cse_mapper
=
CSEMapper
(
to_eliminate
,
get_key
)
result
=
[
cse_mapper
(
expr
)
for
expr
in
exprs
]
result
=
[
cse_mapper
(
expr
)
for
expr
in
exprs
]
return
result
return
result
This diff is collapsed.
Click to expand it.
pymbolic/primitives.py
+
20
−
0
View file @
ca628341
...
@@ -868,6 +868,26 @@ def is_zero(value):
...
@@ -868,6 +868,26 @@ def is_zero(value):
def
wrap_in_cse
(
expr
,
prefix
=
None
):
if
isinstance
(
expr
,
Variable
):
return
expr
if
isinstance
(
expr
,
CommonSubexpression
):
if
prefix
is
None
:
return
expr
if
expr
.
prefix
is
None
:
return
CommonSubexpression
(
expr
.
child
,
prefix
)
# existing prefix wins
return
expr
else
:
return
CommonSubexpression
(
expr
,
prefix
)
def
make_common_subexpression
(
field
,
prefix
=
None
):
def
make_common_subexpression
(
field
,
prefix
=
None
):
try
:
try
:
from
pytools.obj_array
import
log_shape
from
pytools.obj_array
import
log_shape
...
...
This diff is collapsed.
Click to expand it.
pymbolic/sympy_conv.py
+
0
−
2
View file @
ca628341
...
@@ -53,8 +53,6 @@ class ToPymbolicMapper(_SympyMapper):
...
@@ -53,8 +53,6 @@ class ToPymbolicMapper(_SympyMapper):
if
prim
.
is_zero
(
denom
-
1
):
if
prim
.
is_zero
(
denom
-
1
):
return
num
return
num
if
isinstance
(
num
,
int
)
and
isinstance
(
denom
,
int
):
return
int
(
num
)
/
int
(
denom
)
return
prim
.
Quotient
(
num
,
denom
)
return
prim
.
Quotient
(
num
,
denom
)
def
map_Pow
(
self
,
expr
):
def
map_Pow
(
self
,
expr
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment