Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
pymbolic
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Lawrence Mitchell
pymbolic
Commits
3b7bdf6c
Commit
3b7bdf6c
authored
13 years ago
by
Andreas Klöckner
Browse files
Options
Downloads
Patches
Plain Diff
Unifier: Add force_var_match.
parent
748fb495
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pymbolic/mapper/unifier.py
+42
-38
42 additions, 38 deletions
pymbolic/mapper/unifier.py
with
42 additions
and
38 deletions
pymbolic/mapper/unifier.py
+
42
−
38
View file @
3b7bdf6c
...
@@ -73,21 +73,29 @@ def unify_many(unis1, uni2):
...
@@ -73,21 +73,29 @@ def unify_many(unis1, uni2):
class
UnifierBase
(
RecursiveMapper
):
class
UnifierBase
(
RecursiveMapper
):
def
__init__
(
self
,
lhs_mapping_candidates
=
None
,
def
__init__
(
self
,
lhs_mapping_candidates
=
None
,
rhs_mapping_candidates
=
None
):
rhs_mapping_candidates
=
None
,
force_var_match
=
True
):
self
.
lhs_mapping_candidates
=
lhs_mapping_candidates
self
.
lhs_mapping_candidates
=
lhs_mapping_candidates
self
.
rhs_mapping_candidates
=
rhs_mapping_candidates
self
.
rhs_mapping_candidates
=
rhs_mapping_candidates
self
.
force_var_match
=
force_var_match
def
unification_record_from_equation
(
self
,
lhs
,
rhs
):
def
unification_record_from_equation
(
self
,
lhs
,
rhs
):
if
isinstance
(
lhs
,
(
tuple
,
list
))
or
isinstance
(
rhs
,
(
tuple
,
list
)):
if
isinstance
(
lhs
,
(
tuple
,
list
))
or
isinstance
(
rhs
,
(
tuple
,
list
)):
# must match elementwise!
# must match elementwise!
return
None
return
None
lhs_is_var
=
isinstance
(
lhs
,
Variable
)
rhs_is_var
=
isinstance
(
rhs
,
Variable
)
if
self
.
force_var_match
and
not
(
lhs_is_var
or
rhs_is_var
):
return
None
if
(
self
.
lhs_mapping_candidates
is
not
None
if
(
self
.
lhs_mapping_candidates
is
not
None
and
isinstance
(
lhs
,
Variable
)
and
lhs_is_var
and
lhs
.
name
not
in
self
.
lhs_mapping_candidates
):
and
lhs
.
name
not
in
self
.
lhs_mapping_candidates
):
return
None
return
None
if
(
self
.
rhs_mapping_candidates
is
not
None
if
(
self
.
rhs_mapping_candidates
is
not
None
and
isinstance
(
rhs
,
Variable
)
and
rhs_is_var
and
rhs
.
name
not
in
self
.
rhs_mapping_candidates
):
and
rhs
.
name
not
in
self
.
rhs_mapping_candidates
):
return
None
return
None
...
@@ -122,17 +130,46 @@ class UnifierBase(RecursiveMapper):
...
@@ -122,17 +130,46 @@ class UnifierBase(RecursiveMapper):
def
map_lookup
(
self
,
expr
,
other
,
unis
):
def
map_lookup
(
self
,
expr
,
other
,
unis
):
if
not
isinstance
(
other
,
type
(
expr
)):
if
not
isinstance
(
other
,
type
(
expr
)):
return
self
.
treat_mismatch
(
expr
,
other
,
unis
)
return
self
.
treat_mismatch
(
expr
,
other
,
unis
)
if
self
.
name
!=
other
.
name
:
if
expr
.
name
!=
other
.
name
:
return
[]
return
[]
return
self
.
rec
(
expr
.
aggregate
,
other
.
aggregate
,
unis
)
return
self
.
rec
(
expr
.
aggregate
,
other
.
aggregate
,
unis
)
def
map_sum
(
self
,
expr
,
other
,
unis
):
if
(
not
isinstance
(
other
,
type
(
expr
))
or
len
(
expr
.
children
)
!=
len
(
other
.
children
)):
return
[]
result
=
[]
from
pytools
import
generate_permutations
had_structural_match
=
False
for
perm
in
generate_permutations
(
range
(
len
(
expr
.
children
))):
it_assignments
=
unis
for
my_child
,
other_child
in
zip
(
expr
.
children
,
(
other
.
children
[
i
]
for
i
in
perm
)):
it_assignments
=
self
.
rec
(
my_child
,
other_child
,
it_assignments
)
if
not
it_assignments
:
break
if
it_assignments
:
had_structural_match
=
True
result
.
extend
(
it_assignments
)
if
not
had_structural_match
:
return
self
.
treat_mismatch
(
expr
,
other
,
unis
)
return
result
map_product
=
map_sum
def
map_negation
(
self
,
expr
,
other
,
unis
):
def
map_negation
(
self
,
expr
,
other
,
unis
):
if
not
isinstance
(
other
,
type
(
expr
)):
if
not
isinstance
(
other
,
type
(
expr
)):
return
self
.
treat_mismatch
(
expr
,
other
,
unis
)
return
self
.
treat_mismatch
(
expr
,
other
,
unis
)
return
self
.
rec
(
expr
.
child
,
other
.
child
,
unis
)
return
self
.
rec
(
expr
.
child
,
other
.
child
,
unis
)
def
map_quotient
(
self
,
expr
,
other
,
unis
):
def
map_quotient
(
self
,
expr
,
other
,
unis
):
if
not
isinstance
(
other
,
type
(
expr
)):
if
not
isinstance
(
other
,
type
(
expr
)):
return
self
.
treat_mismatch
(
expr
,
other
,
unis
)
return
self
.
treat_mismatch
(
expr
,
other
,
unis
)
...
@@ -162,7 +199,6 @@ class UnifierBase(RecursiveMapper):
...
@@ -162,7 +199,6 @@ class UnifierBase(RecursiveMapper):
return
unis
return
unis
map_tuple
=
map_list
map_tuple
=
map_list
def
__call__
(
self
,
expr
,
other
,
unis
=
None
):
def
__call__
(
self
,
expr
,
other
,
unis
=
None
):
...
@@ -178,36 +214,6 @@ class UnidirectionalUnifier(UnifierBase):
...
@@ -178,36 +214,6 @@ class UnidirectionalUnifier(UnifierBase):
subexpression of the second.
subexpression of the second.
"""
"""
def
map_sum
(
self
,
expr
,
other
,
unis
):
if
(
not
isinstance
(
other
,
type
(
expr
))
or
len
(
expr
.
children
)
!=
len
(
other
.
children
)):
return
[]
result
=
[]
from
pytools
import
generate_permutations
had_structural_match
=
False
for
perm
in
generate_permutations
(
range
(
len
(
expr
.
children
))):
it_assignments
=
unis
for
my_child
,
other_child
in
zip
(
expr
.
children
,
(
other
.
children
[
i
]
for
i
in
perm
)):
it_assignments
=
self
.
rec
(
my_child
,
other_child
,
it_assignments
)
if
not
it_assignments
:
break
if
it_assignments
:
had_structural_match
=
True
result
.
extend
(
it_assignments
)
if
not
had_structural_match
:
return
self
.
treat_mismatch
(
expr
,
other
,
unis
)
return
result
map_product
=
map_sum
def
treat_mismatch
(
self
,
expr
,
other
,
unis
):
def
treat_mismatch
(
self
,
expr
,
other
,
unis
):
return
[]
return
[]
...
@@ -219,5 +225,3 @@ class BidirectionalUnifier(UnifierBase):
...
@@ -219,5 +225,3 @@ class BidirectionalUnifier(UnifierBase):
"""
"""
treat_mismatch
=
UnifierBase
.
map_variable
treat_mismatch
=
UnifierBase
.
map_variable
map_sum
=
UnifierBase
.
map_variable
map_product
=
UnifierBase
.
map_variable
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment