Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2d46d60e
提交
2d46d60e
authored
8月 27, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
8月 28, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove unused rewrites and functionality
上级
40ccab1a
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
46 行增加
和
518 行删除
+46
-518
graph_rewriting.rst
doc/extending/graph_rewriting.rst
+1
-1
features.rst
doc/library/graph/features.rst
+0
-4
features.py
pytensor/graph/features.py
+0
-94
basic.py
pytensor/graph/rewriting/basic.py
+6
-268
basic.py
pytensor/tensor/rewriting/basic.py
+4
-2
math.py
pytensor/tensor/rewriting/math.py
+0
-9
test_types.py
tests/compile/function/test_types.py
+2
-2
test_basic.py
tests/graph/rewriting/test_basic.py
+30
-54
test_destroyhandler.py
tests/graph/test_destroyhandler.py
+1
-2
test_features.py
tests/graph/test_features.py
+2
-82
没有找到文件。
doc/extending/graph_rewriting.rst
浏览文件 @
2d46d60e
...
...
@@ -134,7 +134,7 @@ computation graph.
In a nutshell, :class:`ReplaceValidate` grants access to :meth:`fgraph.replace_validate`,
and :meth:`fgraph.replace_validate` allows us to replace a :class:`Variable` with
another while respecting certain validation constraints. As an
exercise, try to rewrite :class:`Simplify` using :class:`
NodeFind
er`. (Hint: you
exercise, try to rewrite :class:`Simplify` using :class:`
WalkingGraphRewrit
er`. (Hint: you
want to use the method it publishes instead of the call to toposort)
Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by
...
...
doc/library/graph/features.rst
浏览文件 @
2d46d60e
...
...
@@ -26,7 +26,3 @@ Guide
.. class:: ReplaceValidate(History, Validator)
.. method:: replace_validate(fgraph, var, new_var, reason=None)
.. class:: NodeFinder(Bookkeeper)
.. class:: PrintListener(object)
pytensor/graph/features.py
浏览文件 @
2d46d60e
...
...
@@ -827,100 +827,6 @@ class ReplaceValidate(History, Validator):
raise
InconsistencyError
(
"Trying to reintroduce a removed node"
)
class
NodeFinder
(
Bookkeeper
):
def
__init__
(
self
):
self
.
fgraph
=
None
self
.
d
=
{}
def
on_attach
(
self
,
fgraph
):
if
hasattr
(
fgraph
,
"get_nodes"
):
raise
AlreadyThere
(
"NodeFinder is already present"
)
if
self
.
fgraph
is
not
None
and
self
.
fgraph
!=
fgraph
:
raise
Exception
(
"A NodeFinder instance can only serve one FunctionGraph."
)
self
.
fgraph
=
fgraph
fgraph
.
get_nodes
=
partial
(
self
.
query
,
fgraph
)
Bookkeeper
.
on_attach
(
self
,
fgraph
)
def
clone
(
self
):
return
type
(
self
)()
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if
self
.
fgraph
is
not
fgraph
:
raise
Exception
(
"This NodeFinder instance was not attached to the provided fgraph."
)
self
.
fgraph
=
None
del
fgraph
.
get_nodes
Bookkeeper
.
on_detach
(
self
,
fgraph
)
def
on_import
(
self
,
fgraph
,
node
,
reason
):
try
:
self
.
d
.
setdefault
(
node
.
op
,
[])
.
append
(
node
)
except
TypeError
:
# node.op is unhashable
return
except
Exception
as
e
:
print
(
"OFFENDING node"
,
type
(
node
),
type
(
node
.
op
),
file
=
sys
.
stderr
)
# noqa: T201
try
:
print
(
"OFFENDING node hash"
,
hash
(
node
.
op
),
file
=
sys
.
stderr
)
# noqa: T201
except
Exception
:
print
(
"OFFENDING node not hashable"
,
file
=
sys
.
stderr
)
# noqa: T201
raise
e
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
try
:
nodes
=
self
.
d
[
node
.
op
]
except
TypeError
:
# node.op is unhashable
return
nodes
.
remove
(
node
)
if
not
nodes
:
del
self
.
d
[
node
.
op
]
def
query
(
self
,
fgraph
,
op
):
try
:
all
=
self
.
d
.
get
(
op
,
[])
except
TypeError
:
raise
TypeError
(
f
"{op} in unhashable and cannot be queried by the optimizer"
)
all
=
list
(
all
)
return
all
class
PrintListener
(
Feature
):
def
__init__
(
self
,
active
=
True
):
self
.
active
=
active
def
on_attach
(
self
,
fgraph
):
if
self
.
active
:
print
(
"-- attaching to: "
,
fgraph
)
# noqa: T201
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if
self
.
active
:
print
(
"-- detaching from: "
,
fgraph
)
# noqa: T201
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
print
(
f
"-- importing: {node}, reason: {reason}"
)
# noqa: T201
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
print
(
f
"-- pruning: {node}, reason: {reason}"
)
# noqa: T201
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
if
self
.
active
:
print
(
f
"-- changing ({node}.inputs[{i}]) from {r} to {new_r}"
)
# noqa: T201
class
PreserveVariableAttributes
(
Feature
):
"""
This preserve some variables attributes and tag during optimization.
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
2d46d60e
...
...
@@ -11,8 +11,7 @@ import traceback
import
warnings
from
collections
import
Counter
,
UserList
,
defaultdict
,
deque
from
collections.abc
import
Callable
,
Iterable
,
Sequence
from
collections.abc
import
Iterable
as
IterableType
from
functools
import
_compose_mro
,
partial
,
reduce
# type: ignore
from
functools
import
_compose_mro
,
partial
# type: ignore
from
itertools
import
chain
from
typing
import
TYPE_CHECKING
,
Literal
...
...
@@ -28,7 +27,7 @@ from pytensor.graph.basic import (
io_toposort
,
vars_between
,
)
from
pytensor.graph.features
import
AlreadyThere
,
Feature
,
NodeFinder
from
pytensor.graph.features
import
AlreadyThere
,
Feature
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.utils
import
AssocList
,
InconsistencyError
...
...
@@ -60,14 +59,6 @@ FailureCallbackType = Callable[
]
class
MetaNodeRewriterSkip
(
AssertionError
):
"""This is an `AssertionError`, but instead of having the
`MetaNodeRewriter` print the error, it just skip that
compilation.
"""
class
Rewriter
(
abc
.
ABC
):
"""Abstract base class for graph/term rewriters."""
...
...
@@ -942,129 +933,6 @@ def pre_constant_merge(fgraph, variables):
return
[
recursive_merge
(
v
)
for
v
in
variables
]
class
MetaNodeRewriter
(
NodeRewriter
):
r"""
Base class for meta-rewriters that try a set of `NodeRewriter`\s
to replace a node and choose the one that executes the fastest.
If the error `MetaNodeRewriterSkip` is raised during
compilation, we will skip that function compilation and not print
the error.
"""
def
__init__
(
self
):
self
.
verbose
=
config
.
metaopt__verbose
self
.
track_dict
=
defaultdict
(
list
)
self
.
tag_dict
=
defaultdict
(
list
)
self
.
_tracks
=
[]
self
.
rewriters
=
[]
def
register
(
self
,
rewriter
:
NodeRewriter
,
tag_list
:
IterableType
[
str
]):
self
.
rewriters
.
append
(
rewriter
)
tracks
=
rewriter
.
tracks
()
if
tracks
:
self
.
_tracks
.
extend
(
tracks
)
for
c
in
tracks
:
self
.
track_dict
[
c
]
.
append
(
rewriter
)
for
tag
in
tag_list
:
self
.
tag_dict
[
tag
]
.
append
(
rewriter
)
def
tracks
(
self
):
return
self
.
_tracks
def
transform
(
self
,
fgraph
,
node
,
*
args
,
**
kwargs
):
# safety check: depending on registration, tracks may have been ignored
if
self
.
_tracks
is
not
None
:
if
not
isinstance
(
node
.
op
,
tuple
(
self
.
_tracks
)):
return
# first, we need to provide dummy values for all inputs
# to the node that are not shared variables anyway
givens
=
{}
missing
=
set
()
for
input
in
node
.
inputs
:
if
isinstance
(
input
,
pytensor
.
compile
.
SharedVariable
):
pass
elif
hasattr
(
input
.
tag
,
"test_value"
):
givens
[
input
]
=
pytensor
.
shared
(
input
.
type
.
filter
(
input
.
tag
.
test_value
),
input
.
name
,
shape
=
input
.
broadcastable
,
borrow
=
True
,
)
else
:
missing
.
add
(
input
)
if
missing
:
givens
.
update
(
self
.
provide_inputs
(
node
,
missing
))
missing
.
difference_update
(
givens
.
keys
())
# ensure we have data for all input variables that need it
if
missing
:
if
self
.
verbose
>
0
:
print
(
# noqa: T201
f
"{self.__class__.__name__} cannot meta-rewrite {node}, "
f
"{len(missing)} of {int(node.nin)} input shapes unknown"
)
return
# now we can apply the different rewrites in turn,
# compile the resulting subgraphs and time their execution
if
self
.
verbose
>
1
:
print
(
# noqa: T201
f
"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):"
)
timings
=
[]
for
node_rewriter
in
self
.
get_rewrites
(
node
):
outputs
=
node_rewriter
.
transform
(
fgraph
,
node
,
*
args
,
**
kwargs
)
if
outputs
:
try
:
fn
=
pytensor
.
function
(
[],
outputs
,
givens
=
givens
,
on_unused_input
=
"ignore"
)
fn
.
trust_input
=
True
timing
=
min
(
self
.
time_call
(
fn
)
for
_
in
range
(
2
))
except
MetaNodeRewriterSkip
:
continue
except
Exception
as
e
:
if
self
.
verbose
>
0
:
print
(
f
"* {node_rewriter}: exception"
,
e
)
# noqa: T201
continue
else
:
if
self
.
verbose
>
1
:
print
(
f
"* {node_rewriter}: {timing:.5g} sec"
)
# noqa: T201
timings
.
append
((
timing
,
outputs
,
node_rewriter
))
else
:
if
self
.
verbose
>
0
:
print
(
f
"* {node_rewriter}: not applicable"
)
# noqa: T201
# finally, we choose the fastest one
if
timings
:
timings
.
sort
()
if
self
.
verbose
>
1
:
print
(
f
"= {timings[0][2]}"
)
# noqa: T201
return
timings
[
0
][
1
]
return
def
provide_inputs
(
self
,
node
,
inputs
):
"""Return a dictionary mapping some `inputs` to `SharedVariable` instances of with dummy values.
The `node` argument can be inspected to infer required input shapes.
"""
raise
NotImplementedError
()
def
get_rewrites
(
self
,
node
):
"""Return the rewrites that apply to `node`.
This uses ``self.track_dict[type(node.op)]`` by default.
"""
return
self
.
track_dict
[
type
(
node
.
op
)]
def
time_call
(
self
,
fn
):
start
=
time
.
perf_counter
()
fn
()
return
time
.
perf_counter
()
-
start
class
FromFunctionNodeRewriter
(
NodeRewriter
):
"""A `NodeRewriter` constructed from a function."""
...
...
@@ -1214,9 +1082,6 @@ class SequentialNodeRewriter(NodeRewriter):
reentrant : bool
Some global rewriters, like `NodeProcessingGraphRewriter`, use this value to
determine if they should ignore new nodes.
retains_inputs : bool
States whether or not the inputs of a transformed node are transferred
to the outputs.
"""
def
__init__
(
...
...
@@ -1247,9 +1112,6 @@ class SequentialNodeRewriter(NodeRewriter):
self
.
reentrant
=
any
(
getattr
(
rewrite
,
"reentrant"
,
True
)
for
rewrite
in
rewriters
)
self
.
retains_inputs
=
all
(
getattr
(
rewrite
,
"retains_inputs"
,
False
)
for
rewrite
in
rewriters
)
self
.
apply_all_rewrites
=
apply_all_rewrites
...
...
@@ -1425,17 +1287,12 @@ class SubstitutionNodeRewriter(NodeRewriter):
# an SubstitutionNodeRewriter does not apply to the nodes it produces
reentrant
=
False
# all the inputs of the original node are transferred to the outputs
retains_inputs
=
True
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
self
.
op1
=
op1
self
.
op2
=
op2
self
.
transfer_tags
=
transfer_tags
def
op_key
(
self
):
return
self
.
op1
def
tracks
(
self
):
return
[
self
.
op1
]
...
...
@@ -1453,39 +1310,6 @@ class SubstitutionNodeRewriter(NodeRewriter):
return
f
"{self.op1} -> {self.op2}"
class
RemovalNodeRewriter
(
NodeRewriter
):
"""
Removes all applications of an `Op` by transferring each of its
outputs to the corresponding input.
"""
reentrant
=
False
# no nodes are added at all
def
__init__
(
self
,
op
):
self
.
op
=
op
def
op_key
(
self
):
return
self
.
op
def
tracks
(
self
):
return
[
self
.
op
]
def
transform
(
self
,
fgraph
,
node
):
if
node
.
op
!=
self
.
op
:
return
False
return
node
.
inputs
def
__str__
(
self
):
return
f
"{self.op}(x) -> x"
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
(
f
"{' ' * level}{self.__class__.__name__}(self.op) id={id(self)}"
,
file
=
stream
,
)
class
PatternNodeRewriter
(
NodeRewriter
):
"""Replace all occurrences of an input pattern with an output pattern.
...
...
@@ -1545,7 +1369,6 @@ class PatternNodeRewriter(NodeRewriter):
in_pattern
,
out_pattern
,
allow_multiple_clients
:
bool
=
False
,
skip_identities_fn
=
None
,
name
:
str
|
None
=
None
,
tracks
=
(),
get_nodes
=
None
,
...
...
@@ -1563,8 +1386,6 @@ class PatternNodeRewriter(NodeRewriter):
allow_multiple_clients
If ``False``, the pattern matching will fail if one of the subpatterns has
more than one client.
skip_identities_fn
TODO
name
Set the name of this rewriter.
tracks
...
...
@@ -1574,15 +1395,15 @@ class PatternNodeRewriter(NodeRewriter):
function that takes the tracked node and returns a list of nodes on
which we will try this rewrite.
values_eq_approx
TODO
If specified, this value will be assigned to the ``values_eq_approx``
tag of the output variable. This is used by DebugMode to determine if rewrites are correct.
allow_cast
Automatically cast the output of the rewrite whenever new and old types differ
Notes
-----
`tracks` and `get_nodes` can be used to make this rewrite track a less
frequent `Op`, which will prevent the rewrite from being tried as
often.
frequent `Op`, which will prevent the rewrite from being tried as often.
"""
from
pytensor.graph.rewriting.unify
import
convert_strs_to_vars
...
...
@@ -1600,9 +1421,7 @@ class PatternNodeRewriter(NodeRewriter):
raise
TypeError
(
"The pattern to search for must start with a specific Op instance."
)
self
.
__doc__
=
f
"{self.__class__.__doc__}
\n\n
This instance does: {self}
\n
"
self
.
allow_multiple_clients
=
allow_multiple_clients
self
.
skip_identities_fn
=
skip_identities_fn
if
name
:
self
.
__name__
=
name
self
.
_tracks
=
tracks
...
...
@@ -1610,9 +1429,6 @@ class PatternNodeRewriter(NodeRewriter):
if
tracks
!=
():
assert
get_nodes
def
op_key
(
self
):
return
self
.
op
def
tracks
(
self
):
if
self
.
_tracks
!=
():
return
self
.
_tracks
...
...
@@ -2136,7 +1952,7 @@ def walking_rewriter(
else
:
(
node_rewriters
,)
=
node_rewriters
if
not
name
:
name
=
node_rewriters
.
__name__
name
=
getattr
(
node_rewriters
,
"__name__"
,
None
)
ret
=
WalkingGraphRewriter
(
node_rewriters
,
order
=
order
,
...
...
@@ -2152,52 +1968,6 @@ in2out = partial(walking_rewriter, "in_to_out")
out2in
=
partial
(
walking_rewriter
,
"out_to_in"
)
class
OpKeyGraphRewriter
(
NodeProcessingGraphRewriter
):
r"""A rewriter that applies a `NodeRewriter` to specific `Op`\s.
The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either
as a list of `Op`\s or a single `Op`), and discovered within a
`FunctionGraph` using the `NodeFinder` `Feature`.
This is similar to the `Op`-based tracking feature used by other rewriters.
"""
def
__init__
(
self
,
node_rewriter
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
if
not
hasattr
(
node_rewriter
,
"op_key"
):
raise
TypeError
(
f
"{node_rewriter} must have an `op_key` method."
)
super
()
.
__init__
(
node_rewriter
,
ignore_newtrees
,
failure_callback
)
def
apply
(
self
,
fgraph
):
op
=
self
.
node_rewriter
.
op_key
()
if
isinstance
(
op
,
list
|
tuple
):
q
=
reduce
(
list
.
__iadd__
,
map
(
fgraph
.
get_nodes
,
op
))
else
:
q
=
list
(
fgraph
.
get_nodes
(
op
))
def
importer
(
node
):
if
node
is
not
current_node
:
if
node
.
op
==
op
:
q
.
append
(
node
)
u
=
self
.
attach_updater
(
fgraph
,
importer
,
None
,
name
=
getattr
(
self
,
"name"
,
None
)
)
try
:
while
q
:
node
=
q
.
pop
()
if
node
not
in
fgraph
.
apply_nodes
:
continue
current_node
=
node
self
.
process_node
(
fgraph
,
node
)
finally
:
self
.
detach_updater
(
fgraph
,
u
)
def
add_requirements
(
self
,
fgraph
):
super
()
.
add_requirements
(
fgraph
)
fgraph
.
attach_feature
(
NodeFinder
())
class
ChangeTracker
(
Feature
):
def
__init__
(
self
):
self
.
changed
=
False
...
...
@@ -2785,38 +2555,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
)
def
_check_chain
(
r
,
chain
):
"""
WRITEME
"""
chain
=
list
(
reversed
(
chain
))
while
chain
:
elem
=
chain
.
pop
()
if
elem
is
None
:
if
r
.
owner
is
not
None
:
return
False
elif
r
.
owner
is
None
:
return
False
elif
isinstance
(
elem
,
Op
):
if
r
.
owner
.
op
!=
elem
:
return
False
else
:
try
:
if
issubclass
(
elem
,
Op
)
and
not
isinstance
(
r
.
owner
.
op
,
elem
):
return
False
except
TypeError
:
return
False
if
chain
:
r
=
r
.
owner
.
inputs
[
chain
.
pop
()]
# print 'check_chain', _check_chain.n_calls
# _check_chain.n_calls += 1
# The return value will be used as a Boolean, but some Variables cannot
# be used as Booleans (the results of comparisons, for instance)
return
r
is
not
None
def
pre_greedy_node_rewriter
(
fgraph
:
FunctionGraph
,
rewrites
:
Sequence
[
NodeRewriter
],
out
:
Variable
)
->
Variable
:
...
...
pytensor/tensor/rewriting/basic.py
浏览文件 @
2d46d60e
...
...
@@ -34,7 +34,6 @@ from pytensor.graph.basic import Constant
from
pytensor.graph.rewriting.basic
import
(
NodeProcessingGraphRewriter
,
NodeRewriter
,
RemovalNodeRewriter
,
Rewriter
,
copy_stack_trace
,
in2out
,
...
...
@@ -1224,7 +1223,10 @@ def local_merge_alloc(fgraph, node):
return
[
alloc
(
inputs_inner
[
0
],
*
dims_outer
)]
register_canonicalize
(
RemovalNodeRewriter
(
tensor_copy
),
name
=
"remove_tensor_copy"
)
@register_canonicalize
@node_rewriter
(
tracks
=
[
tensor_copy
])
def
remove_tensor_copy
(
fgraph
,
node
):
return
node
.
inputs
@register_specialize
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
2d46d60e
...
...
@@ -3162,13 +3162,6 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
return
np
.
allclose
(
x
,
ref
,
rtol
=
rtol
,
atol
=
atol
)
def
_skip_mul_1
(
r
):
if
r
.
owner
and
r
.
owner
.
op
==
mul
:
not_is_1
=
[
i
for
i
in
r
.
owner
.
inputs
if
not
_is_1
(
i
)]
if
len
(
not_is_1
)
==
1
:
return
not_is_1
[
0
]
def
_is_1
(
expr
):
"""
...
...
@@ -3190,7 +3183,6 @@ logsigm_to_softplus = PatternNodeRewriter(
(
neg
,
(
softplus
,
(
neg
,
"x"
))),
allow_multiple_clients
=
True
,
values_eq_approx
=
values_eq_approx_remove_inf
,
skip_identities_fn
=
_skip_mul_1
,
tracks
=
[
sigmoid
],
get_nodes
=
get_clients_at_depth1
,
)
...
...
@@ -3199,7 +3191,6 @@ log1msigm_to_softplus = PatternNodeRewriter(
(
neg
,
(
softplus
,
"x"
)),
allow_multiple_clients
=
True
,
values_eq_approx
=
values_eq_approx_remove_inf
,
skip_identities_fn
=
_skip_mul_1
,
tracks
=
[
sigmoid
],
get_nodes
=
get_clients_at_depth2
,
)
...
...
tests/compile/function/test_types.py
浏览文件 @
2d46d60e
...
...
@@ -13,7 +13,7 @@ from pytensor.compile.io import In, Out
from
pytensor.compile.mode
import
Mode
,
get_default_mode
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.rewriting.basic
import
OpKeyGraphRewriter
,
PatternNode
Rewriter
from
pytensor.graph.rewriting.basic
import
PatternNodeRewriter
,
WalkingGraph
Rewriter
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.link.vm
import
VMLinker
from
pytensor.printing
import
debugprint
...
...
@@ -39,7 +39,7 @@ pytestmark = pytest.mark.filterwarnings("error")
def
PatternOptimizer
(
p1
,
p2
,
ign
=
True
):
return
OpKey
GraphRewriter
(
PatternNodeRewriter
(
p1
,
p2
),
ignore_newtrees
=
ign
)
return
Walking
GraphRewriter
(
PatternNodeRewriter
(
p1
,
p2
),
ignore_newtrees
=
ign
)
class
TestFunction
:
...
...
tests/graph/rewriting/test_basic.py
浏览文件 @
2d46d60e
...
...
@@ -8,11 +8,9 @@ from pytensor.graph.op import Op
from
pytensor.graph.rewriting.basic
import
(
EquilibriumGraphRewriter
,
MergeOptimizer
,
OpKeyGraphRewriter
,
OpToRewriterTracker
,
PatternNodeRewriter
,
SequentialNodeRewriter
,
SubstitutionNodeRewriter
,
WalkingGraphRewriter
,
in2out
,
logging
,
...
...
@@ -51,33 +49,29 @@ class AssertNoChanges(Feature):
raise
AssertionError
()
def
OpKey
PatternNodeRewriter
(
p1
,
p2
,
allow_multiple_clients
=
False
,
ign
=
False
):
return
OpKey
GraphRewriter
(
def
Walking
PatternNodeRewriter
(
p1
,
p2
,
allow_multiple_clients
=
False
,
ign
=
False
):
return
Walking
GraphRewriter
(
PatternNodeRewriter
(
p1
,
p2
,
allow_multiple_clients
=
allow_multiple_clients
),
ignore_newtrees
=
ign
,
)
def
WalkingPatternNodeRewriter
(
p1
,
p2
,
ign
=
True
):
return
WalkingGraphRewriter
(
PatternNodeRewriter
(
p1
,
p2
),
ignore_newtrees
=
ign
)
class
TestPatternNodeRewriter
:
def
test_replace_output
(
self
):
# replacing the whole graph
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op2
(
x
,
y
),
z
)
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKeyPatternNodeRewriter
((
op1
,
(
op2
,
"1"
,
"2"
),
"3"
),
(
op4
,
"3"
,
"2"
))
.
rewrite
(
g
)
WalkingPatternNodeRewriter
(
(
op1
,
(
op2
,
"1"
,
"2"
),
"3"
),
(
op4
,
"3"
,
"2"
)
)
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op4(z, y))"
def
test_nested_out_pattern
(
self
):
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
x
,
y
)
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
(
op1
,
"1"
,
"2"
),
(
op4
,
(
op1
,
"1"
),
(
op2
,
"2"
),
(
op3
,
"1"
,
"2"
))
)
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))"
...
...
@@ -86,7 +80,7 @@ class TestPatternNodeRewriter:
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op2
(
x
,
x
),
z
)
# the arguments to op2 are the same
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
(
op1
,
(
op2
,
"1"
,
"1"
),
"2"
),
# they are the same in the pattern
(
op4
,
"2"
,
"1"
),
)
.
rewrite
(
g
)
...
...
@@ -97,7 +91,7 @@ class TestPatternNodeRewriter:
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op2
(
x
,
y
),
z
)
# the arguments to op2 are different
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
(
op1
,
(
op2
,
"1"
,
"1"
),
"2"
),
# they are the same in the pattern
(
op4
,
"2"
,
"1"
),
)
.
rewrite
(
g
)
...
...
@@ -109,7 +103,7 @@ class TestPatternNodeRewriter:
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op2
(
x
,
y
),
z
)
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKey
PatternNodeRewriter
((
op2
,
"1"
,
"2"
),
(
op1
,
"2"
,
"1"
))
.
rewrite
(
g
)
Walking
PatternNodeRewriter
((
op2
,
"1"
,
"2"
),
(
op1
,
"2"
,
"1"
))
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op1(Op1(y, x), z))"
def
test_no_recurse
(
self
):
...
...
@@ -119,7 +113,9 @@ class TestPatternNodeRewriter:
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op2
(
x
,
y
),
z
)
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKeyPatternNodeRewriter
((
op2
,
"1"
,
"2"
),
(
op2
,
"2"
,
"1"
),
ign
=
True
)
.
rewrite
(
g
)
WalkingPatternNodeRewriter
((
op2
,
"1"
,
"2"
),
(
op2
,
"2"
,
"1"
),
ign
=
True
)
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op1(Op2(y, x), z))"
def
test_multiple
(
self
):
...
...
@@ -127,7 +123,7 @@ class TestPatternNodeRewriter:
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op2
(
x
,
y
),
op2
(
x
,
y
),
op2
(
y
,
z
))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKey
PatternNodeRewriter
((
op2
,
"1"
,
"2"
),
(
op4
,
"1"
))
.
rewrite
(
g
)
Walking
PatternNodeRewriter
((
op2
,
"1"
,
"2"
),
(
op4
,
"1"
))
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))"
def
test_nested_even
(
self
):
...
...
@@ -136,21 +132,21 @@ class TestPatternNodeRewriter:
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op1
(
op1
(
op1
(
x
))))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKey
PatternNodeRewriter
((
op1
,
(
op1
,
"1"
)),
"1"
)
.
rewrite
(
g
)
Walking
PatternNodeRewriter
((
op1
,
(
op1
,
"1"
)),
"1"
)
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(x)"
def
test_nested_odd
(
self
):
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op1
(
op1
(
op1
(
op1
(
x
)))))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKey
PatternNodeRewriter
((
op1
,
(
op1
,
"1"
)),
"1"
)
.
rewrite
(
g
)
Walking
PatternNodeRewriter
((
op1
,
(
op1
,
"1"
)),
"1"
)
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op1(x))"
def
test_expand
(
self
):
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op1
(
op1
(
x
)))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKey
PatternNodeRewriter
((
op1
,
"1"
),
(
op2
,
(
op1
,
"1"
)),
ign
=
True
)
.
rewrite
(
g
)
Walking
PatternNodeRewriter
((
op1
,
"1"
),
(
op2
,
(
op1
,
"1"
)),
ign
=
True
)
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))"
def
test_ambiguous
(
self
):
...
...
@@ -169,7 +165,7 @@ class TestPatternNodeRewriter:
z
=
Constant
(
MyType
(),
2
,
name
=
"z"
)
e
=
op1
(
op1
(
x
,
y
),
y
)
g
=
FunctionGraph
([
y
],
[
e
])
OpKey
PatternNodeRewriter
((
op1
,
z
,
"1"
),
(
op2
,
"1"
,
z
))
.
rewrite
(
g
)
Walking
PatternNodeRewriter
((
op1
,
z
,
"1"
),
(
op2
,
"1"
,
z
))
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op1(Op2(y, z{2}), y))"
def
test_constraints
(
self
):
...
...
@@ -181,7 +177,7 @@ class TestPatternNodeRewriter:
# Only replacing if the input is an instance of Op2
return
r
.
owner
.
op
==
op2
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
(
op1
,
{
"pattern"
:
"1"
,
"constraint"
:
constraint
}),
(
op3
,
"1"
)
)
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
...
...
@@ -190,7 +186,7 @@ class TestPatternNodeRewriter:
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
x
,
x
)
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKey
PatternNodeRewriter
((
op1
,
"x"
,
"y"
),
(
op3
,
"x"
,
"y"
))
.
rewrite
(
g
)
Walking
PatternNodeRewriter
((
op1
,
"x"
,
"y"
),
(
op3
,
"x"
,
"y"
))
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op3(x, x))"
@pytest.mark.xfail
(
...
...
@@ -202,10 +198,10 @@ class TestPatternNodeRewriter:
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
def
constraint
(
r
):
# Only replacing if the input
is an instance of Op2
# Only replacing if the input
s are not identical
return
r
.
owner
.
inputs
[
0
]
is
not
r
.
owner
.
inputs
[
1
]
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
{
"pattern"
:
(
op1
,
"x"
,
"y"
),
"constraint"
:
constraint
},
(
op3
,
"x"
,
"y"
)
)
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
...
...
@@ -220,7 +216,7 @@ class TestPatternNodeRewriter:
# So the replacement should fail
outputs
=
[
e
]
g
=
FunctionGraph
(
inputs
,
outputs
,
copy_inputs
=
False
)
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
(
op4
,
(
op1
,
"x"
,
"y"
)),
(
op3
,
"x"
,
"y"
),
)
.
rewrite
(
g
)
...
...
@@ -228,7 +224,7 @@ class TestPatternNodeRewriter:
# Now it should be fine
g
=
FunctionGraph
(
inputs
,
outputs
,
copy_inputs
=
False
)
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
(
op4
,
(
op1
,
"x"
,
"y"
)),
(
op3
,
"x"
,
"y"
),
allow_multiple_clients
=
True
,
...
...
@@ -237,7 +233,7 @@ class TestPatternNodeRewriter:
# The fact that the inputs of the pattern have multiple clients should not matter
g
=
FunctionGraph
(
inputs
,
outputs
,
copy_inputs
=
False
)
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
(
op3
,
(
op4
,
"w"
),
"w"
),
(
op3
,
"w"
,
"w"
),
allow_multiple_clients
=
False
,
...
...
@@ -252,7 +248,7 @@ class TestPatternNodeRewriter:
outputs
=
[
e1
,
e2
]
g
=
FunctionGraph
(
inputs
,
outputs
,
copy_inputs
=
False
)
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
(
op4
,
(
op4
,
"e"
)),
"e"
,
allow_multiple_clients
=
False
,
...
...
@@ -261,7 +257,7 @@ class TestPatternNodeRewriter:
outputs
=
[
e1
,
e3
]
g
=
FunctionGraph
([
x
,
y
,
z
],
outputs
,
copy_inputs
=
False
)
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
(
op4
,
(
op4
,
"e"
)),
"e"
,
allow_multiple_clients
=
False
,
...
...
@@ -269,7 +265,7 @@ class TestPatternNodeRewriter:
assert
equal_computations
(
g
.
outputs
,
outputs
)
g
=
FunctionGraph
(
inputs
,
outputs
,
copy_inputs
=
False
)
OpKey
PatternNodeRewriter
(
Walking
PatternNodeRewriter
(
(
op4
,
(
op4
,
"e"
)),
"e"
,
allow_multiple_clients
=
True
,
...
...
@@ -281,33 +277,13 @@ class TestPatternNodeRewriter:
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op_y
(
x
,
y
),
z
)
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
OpKeyPatternNodeRewriter
((
op1
,
(
op_z
,
"1"
,
"2"
),
"3"
),
(
op4
,
"3"
,
"2"
))
.
rewrite
(
g
)
WalkingPatternNodeRewriter
(
(
op1
,
(
op_z
,
"1"
,
"2"
),
"3"
),
(
op4
,
"3"
,
"2"
)
)
.
rewrite
(
g
)
str_g
=
str
(
g
)
assert
str_g
==
"FunctionGraph(Op4(z, y))"
def
KeyedSubstitutionNodeRewriter
(
op1
,
op2
):
return
OpKeyGraphRewriter
(
SubstitutionNodeRewriter
(
op1
,
op2
))
class
TestSubstitutionNodeRewriter
:
def
test_straightforward
(
self
):
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op1
(
op1
(
op1
(
op1
(
x
)))))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
KeyedSubstitutionNodeRewriter
(
op1
,
op2
)
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))"
def
test_straightforward_2
(
self
):
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op2
(
x
),
op3
(
y
),
op4
(
z
))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
KeyedSubstitutionNodeRewriter
(
op3
,
op4
)
.
rewrite
(
g
)
assert
str
(
g
)
==
"FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"
class
NoInputOp
(
Op
):
__props__
=
(
"param"
,)
...
...
tests/graph/test_destroyhandler.py
浏览文件 @
2d46d60e
...
...
@@ -10,7 +10,6 @@ from pytensor.graph.fg import FunctionGraph
from
pytensor.graph.op
import
Op
from
pytensor.graph.rewriting.basic
import
(
NodeProcessingGraphRewriter
,
OpKeyGraphRewriter
,
PatternNodeRewriter
,
SubstitutionNodeRewriter
,
WalkingGraphRewriter
,
...
...
@@ -21,7 +20,7 @@ from tests.unittest_tools import assertFailure_fast
def
OpKeyPatternNodeRewriter
(
p1
,
p2
,
ign
=
True
):
return
OpKey
GraphRewriter
(
PatternNodeRewriter
(
p1
,
p2
),
ignore_newtrees
=
ign
)
return
Walking
GraphRewriter
(
PatternNodeRewriter
(
p1
,
p2
),
ignore_newtrees
=
ign
)
def
TopoSubstitutionNodeRewriter
(
...
...
tests/graph/test_features.py
浏览文件 @
2d46d60e
...
...
@@ -2,92 +2,12 @@ import pytest
import
pytensor.tensor
as
pt
from
pytensor.graph
import
rewrite_graph
from
pytensor.graph.basic
import
Apply
,
Variable
,
equal_computations
from
pytensor.graph.features
import
Feature
,
FullHistory
,
NodeFinder
,
ReplaceValidate
from
pytensor.graph.basic
import
equal_computations
from
pytensor.graph.features
import
Feature
,
FullHistory
,
ReplaceValidate
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
Op
from
pytensor.graph.type
import
Type
from
tests.graph.utils
import
MyVariable
,
op1
class
TestNodeFinder
:
def
test_straightforward
(
self
):
class
MyType
(
Type
):
def
__init__
(
self
,
name
):
self
.
name
=
name
def
filter
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
()
def
__str__
(
self
):
return
self
.
name
def
__repr__
(
self
):
return
self
.
name
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
MyType
)
class
MyOp
(
Op
):
__props__
=
(
"nin"
,
"name"
)
def
__init__
(
self
,
nin
,
name
):
self
.
nin
=
nin
self
.
name
=
name
def
make_node
(
self
,
*
inputs
):
def
as_variable
(
x
):
assert
isinstance
(
x
,
Variable
)
return
x
assert
len
(
inputs
)
==
self
.
nin
inputs
=
list
(
map
(
as_variable
,
inputs
))
for
input
in
inputs
:
if
not
isinstance
(
input
.
type
,
MyType
):
raise
Exception
(
"Error 1"
)
outputs
=
[
MyType
(
self
.
name
+
"_R"
)()]
return
Apply
(
self
,
inputs
,
outputs
)
def
__str__
(
self
):
return
self
.
name
def
perform
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
()
sigmoid
=
MyOp
(
1
,
"Sigmoid"
)
add
=
MyOp
(
2
,
"Add"
)
dot
=
MyOp
(
2
,
"Dot"
)
def
MyVariable
(
name
):
return
Variable
(
MyType
(
name
),
None
,
None
)
def
inputs
():
x
=
MyVariable
(
"x"
)
y
=
MyVariable
(
"y"
)
z
=
MyVariable
(
"z"
)
return
x
,
y
,
z
x
,
y
,
z
=
inputs
()
e0
=
dot
(
y
,
z
)
e
=
add
(
add
(
sigmoid
(
x
),
sigmoid
(
sigmoid
(
z
))),
dot
(
add
(
x
,
y
),
e0
))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
],
clone
=
False
)
g
.
attach_feature
(
NodeFinder
())
assert
hasattr
(
g
,
"get_nodes"
)
for
type
,
num
in
((
add
,
3
),
(
sigmoid
,
3
),
(
dot
,
2
)):
if
len
(
list
(
g
.
get_nodes
(
type
)))
!=
num
:
raise
Exception
(
f
"Expected: {num} times {type}"
)
new_e0
=
add
(
y
,
z
)
assert
e0
.
owner
in
g
.
get_nodes
(
dot
)
assert
new_e0
.
owner
not
in
g
.
get_nodes
(
add
)
g
.
replace
(
e0
,
new_e0
)
assert
e0
.
owner
not
in
g
.
get_nodes
(
dot
)
assert
new_e0
.
owner
in
g
.
get_nodes
(
add
)
for
type
,
num
in
((
add
,
4
),
(
sigmoid
,
3
),
(
dot
,
1
)):
if
len
(
list
(
g
.
get_nodes
(
type
)))
!=
num
:
raise
Exception
(
f
"Expected: {num} times {type}"
)
class
TestReplaceValidate
:
def
test_verbose
(
self
,
capsys
):
var1
=
MyVariable
(
"var1"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论