Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
L
loopy
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Ben Sepanski
loopy
Commits
a336b851
Commit
a336b851
authored
9 years ago
by
Andreas Klöckner
Browse files
Options
Downloads
Plain Diff
Merge bodge:src/loopy
parents
4c292b1e
b99c1622
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
loopy/buffer.py
+28
-0
28 additions, 0 deletions
loopy/buffer.py
loopy/context_matching.py
+78
-1
78 additions, 1 deletion
loopy/context_matching.py
with
106 additions
and
1 deletion
loopy/buffer.py
+
28
−
0
View file @
a336b851
...
@@ -29,6 +29,9 @@ from loopy.symbolic import (get_dependencies,
...
@@ -29,6 +29,9 @@ from loopy.symbolic import (get_dependencies,
RuleAwareIdentityMapper
,
SubstitutionRuleMappingContext
,
RuleAwareIdentityMapper
,
SubstitutionRuleMappingContext
,
SubstitutionMapper
)
SubstitutionMapper
)
from
pymbolic.mapper.substitutor
import
make_subst_func
from
pymbolic.mapper.substitutor
import
make_subst_func
from
pytools.persistent_dict
import
PersistentDict
from
loopy.tools
import
LoopyKeyBuilder
from
loopy.version
import
DATA_MODEL_VERSION
from
pymbolic
import
var
from
pymbolic
import
var
...
@@ -117,6 +120,11 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper):
...
@@ -117,6 +120,11 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper):
# }}}
# }}}
buffer_array_cache
=
PersistentDict
(
"
loopy-buffer-array-cachee
"
+
DATA_MODEL_VERSION
,
key_builder
=
LoopyKeyBuilder
())
# Adding an argument? also add something to the cache_key below.
def
buffer_array
(
kernel
,
var_name
,
buffer_inames
,
init_expression
=
None
,
def
buffer_array
(
kernel
,
var_name
,
buffer_inames
,
init_expression
=
None
,
store_expression
=
None
,
within
=
None
,
default_tag
=
"
l.auto
"
,
store_expression
=
None
,
within
=
None
,
default_tag
=
"
l.auto
"
,
temporary_is_local
=
None
,
fetch_bounding_box
=
False
):
temporary_is_local
=
None
,
fetch_bounding_box
=
False
):
...
@@ -173,6 +181,22 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
...
@@ -173,6 +181,22 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
# }}}
# }}}
# {{{ caching
from
loopy
import
CACHING_ENABLED
cache_key
=
(
kernel
,
var_name
,
tuple
(
buffer_inames
),
init_expression
,
store_expression
,
within
,
default_tag
,
temporary_is_local
,
fetch_bounding_box
)
if
CACHING_ENABLED
:
try
:
return
buffer_array_cache
[
cache_key
]
except
KeyError
:
pass
# }}}
var_name_gen
=
kernel
.
get_var_name_generator
()
var_name_gen
=
kernel
.
get_var_name_generator
()
within_inames
=
set
()
within_inames
=
set
()
...
@@ -413,6 +437,10 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
...
@@ -413,6 +437,10 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
from
loopy
import
tag_inames
from
loopy
import
tag_inames
kernel
=
tag_inames
(
kernel
,
new_iname_to_tag
)
kernel
=
tag_inames
(
kernel
,
new_iname_to_tag
)
if
0
and
CACHING_ENABLED
:
from
loopy.preprocess
import
prepare_for_caching
buffer_array_cache
[
cache_key
]
=
prepare_for_caching
(
kernel
)
return
kernel
return
kernel
# vim: foldmethod=marker
# vim: foldmethod=marker
This diff is collapsed.
Click to expand it.
loopy/context_matching.py
+
78
−
1
View file @
a336b851
...
@@ -94,11 +94,21 @@ class MatchExpressionBase(object):
...
@@ -94,11 +94,21 @@ class MatchExpressionBase(object):
def
__call__
(
self
,
kernel
,
matchable
):
def
__call__
(
self
,
kernel
,
matchable
):
raise
NotImplementedError
raise
NotImplementedError
def
__ne__
(
self
,
other
):
return
not
self
.
__eq__
(
other
)
class
AllMatchExpression
(
MatchExpressionBase
):
class
AllMatchExpression
(
MatchExpressionBase
):
def
__call__
(
self
,
kernel
,
matchable
):
def
__call__
(
self
,
kernel
,
matchable
):
return
True
return
True
def
update_persistent_hash
(
self
,
key_hash
,
key_builder
):
key_builder
.
rec
(
key_hash
,
"
all_match_expr
"
)
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
class
AndMatchExpression
(
MatchExpressionBase
):
class
AndMatchExpression
(
MatchExpressionBase
):
def
__init__
(
self
,
children
):
def
__init__
(
self
,
children
):
...
@@ -110,6 +120,14 @@ class AndMatchExpression(MatchExpressionBase):
...
@@ -110,6 +120,14 @@ class AndMatchExpression(MatchExpressionBase):
def
__str__
(
self
):
def
__str__
(
self
):
return
"
(%s)
"
%
(
"
and
"
.
join
(
str
(
ch
)
for
ch
in
self
.
children
))
return
"
(%s)
"
%
(
"
and
"
.
join
(
str
(
ch
)
for
ch
in
self
.
children
))
def
update_persistent_hash
(
self
,
key_hash
,
key_builder
):
key_builder
.
rec
(
key_hash
,
"
and_match_expr
"
)
key_builder
.
rec
(
key_hash
,
self
.
children
)
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
and
self
.
children
==
other
.
children
)
class
OrMatchExpression
(
MatchExpressionBase
):
class
OrMatchExpression
(
MatchExpressionBase
):
def
__init__
(
self
,
children
):
def
__init__
(
self
,
children
):
...
@@ -121,6 +139,14 @@ class OrMatchExpression(MatchExpressionBase):
...
@@ -121,6 +139,14 @@ class OrMatchExpression(MatchExpressionBase):
def
__str__
(
self
):
def
__str__
(
self
):
return
"
(%s)
"
%
(
"
or
"
.
join
(
str
(
ch
)
for
ch
in
self
.
children
))
return
"
(%s)
"
%
(
"
or
"
.
join
(
str
(
ch
)
for
ch
in
self
.
children
))
def
update_persistent_hash
(
self
,
key_hash
,
key_builder
):
key_builder
.
rec
(
key_hash
,
"
or_match_expr
"
)
key_builder
.
rec
(
key_hash
,
self
.
children
)
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
and
self
.
children
==
other
.
children
)
class
NotMatchExpression
(
MatchExpressionBase
):
class
NotMatchExpression
(
MatchExpressionBase
):
def
__init__
(
self
,
child
):
def
__init__
(
self
,
child
):
...
@@ -132,6 +158,14 @@ class NotMatchExpression(MatchExpressionBase):
...
@@ -132,6 +158,14 @@ class NotMatchExpression(MatchExpressionBase):
def
__str__
(
self
):
def
__str__
(
self
):
return
"
(not %s)
"
%
str
(
self
.
child
)
return
"
(not %s)
"
%
str
(
self
.
child
)
def
update_persistent_hash
(
self
,
key_hash
,
key_builder
):
key_builder
.
rec
(
key_hash
,
"
not_match_expr
"
)
key_builder
.
rec
(
key_hash
,
self
.
child
)
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
and
self
.
child
==
other
.
child
)
class
GlobMatchExpressionBase
(
MatchExpressionBase
):
class
GlobMatchExpressionBase
(
MatchExpressionBase
):
def
__init__
(
self
,
glob
):
def
__init__
(
self
,
glob
):
...
@@ -146,6 +180,14 @@ class GlobMatchExpressionBase(MatchExpressionBase):
...
@@ -146,6 +180,14 @@ class GlobMatchExpressionBase(MatchExpressionBase):
descr
=
descr
[:
descr
.
find
(
"
Match
"
)]
descr
=
descr
[:
descr
.
find
(
"
Match
"
)]
return
descr
.
lower
()
+
"
:
"
+
self
.
glob
return
descr
.
lower
()
+
"
:
"
+
self
.
glob
def
update_persistent_hash
(
self
,
key_hash
,
key_builder
):
key_builder
.
rec
(
key_hash
,
type
(
self
).
__name__
)
key_builder
.
rec
(
key_hash
,
self
.
glob
)
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
and
self
.
glob
==
other
.
glob
)
class
IdMatchExpression
(
GlobMatchExpressionBase
):
class
IdMatchExpression
(
GlobMatchExpressionBase
):
def
__call__
(
self
,
kernel
,
matchable
):
def
__call__
(
self
,
kernel
,
matchable
):
...
@@ -284,18 +326,31 @@ def parse_match(expr_str):
...
@@ -284,18 +326,31 @@ def parse_match(expr_str):
# {{{ stack match objects
# {{{ stack match objects
class
StackMatchComponent
(
object
):
class
StackMatchComponent
(
object
):
pass
def
__ne__
(
self
,
other
):
return
not
self
.
__eq__
(
other
)
class
StackAllMatchComponent
(
StackMatchComponent
):
class
StackAllMatchComponent
(
StackMatchComponent
):
def
__call__
(
self
,
kernel
,
stack
):
def
__call__
(
self
,
kernel
,
stack
):
return
True
return
True
def
update_persistent_hash
(
self
,
key_hash
,
key_builder
):
key_builder
.
rec
(
key_hash
,
"
all_match
"
)
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
class
StackBottomMatchComponent
(
StackMatchComponent
):
class
StackBottomMatchComponent
(
StackMatchComponent
):
def
__call__
(
self
,
kernel
,
stack
):
def
__call__
(
self
,
kernel
,
stack
):
return
not
stack
return
not
stack
def
update_persistent_hash
(
self
,
key_hash
,
key_builder
):
key_builder
.
rec
(
key_hash
,
"
bottom_match
"
)
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
class
StackItemMatchComponent
(
StackMatchComponent
):
class
StackItemMatchComponent
(
StackMatchComponent
):
def
__init__
(
self
,
match_expr
,
inner_match
):
def
__init__
(
self
,
match_expr
,
inner_match
):
...
@@ -312,6 +367,16 @@ class StackItemMatchComponent(StackMatchComponent):
...
@@ -312,6 +367,16 @@ class StackItemMatchComponent(StackMatchComponent):
return
self
.
inner_match
(
kernel
,
stack
[
1
:])
return
self
.
inner_match
(
kernel
,
stack
[
1
:])
def
update_persistent_hash
(
self
,
key_hash
,
key_builder
):
key_builder
.
rec
(
key_hash
,
"
item_match
"
)
key_builder
.
rec
(
key_hash
,
self
.
match_expr
)
key_builder
.
rec
(
key_hash
,
self
.
inner_match
)
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
and
self
.
match_expr
==
other
.
match_expr
and
self
.
inner_match
==
other
.
inner_match
)
class
StackWildcardMatchComponent
(
StackMatchComponent
):
class
StackWildcardMatchComponent
(
StackMatchComponent
):
def
__init__
(
self
,
inner_match
):
def
__init__
(
self
,
inner_match
):
...
@@ -348,6 +413,18 @@ class StackMatch(object):
...
@@ -348,6 +413,18 @@ class StackMatch(object):
def
__init__
(
self
,
root_component
):
def
__init__
(
self
,
root_component
):
self
.
root_component
=
root_component
self
.
root_component
=
root_component
def
update_persistent_hash
(
self
,
key_hash
,
key_builder
):
key_builder
.
rec
(
key_hash
,
self
.
root_component
)
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
and
self
.
root_component
==
other
.
root_component
)
def
__ne__
(
self
,
other
):
return
not
self
.
__eq__
(
other
)
def
__call__
(
self
,
kernel
,
insn
,
rule_stack
):
def
__call__
(
self
,
kernel
,
insn
,
rule_stack
):
"""
"""
:arg rule_stack: a tuple of (name, tags) rule invocation, outermost first
:arg rule_stack: a tuple of (name, tags) rule invocation, outermost first
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment