Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2d46d60e
提交
2d46d60e
authored
8月 27, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
8月 28, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove unused rewrites and functionality
上级
40ccab1a
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
10 行增加
和
196 行删除
+10
-196
graph_rewriting.rst
doc/extending/graph_rewriting.rst
+1
-1
features.rst
doc/library/graph/features.rst
+0
-4
features.py
pytensor/graph/features.py
+0
-94
basic.py
pytensor/graph/rewriting/basic.py
+0
-0
basic.py
pytensor/tensor/rewriting/basic.py
+4
-2
math.py
pytensor/tensor/rewriting/math.py
+0
-9
test_types.py
tests/compile/function/test_types.py
+2
-2
test_basic.py
tests/graph/rewriting/test_basic.py
+0
-0
test_destroyhandler.py
tests/graph/test_destroyhandler.py
+1
-2
test_features.py
tests/graph/test_features.py
+2
-82
没有找到文件。
doc/extending/graph_rewriting.rst
浏览文件 @
2d46d60e
...
@@ -134,7 +134,7 @@ computation graph.
...
@@ -134,7 +134,7 @@ computation graph.
In a nutshell, :class:`ReplaceValidate` grants access to :meth:`fgraph.replace_validate`,
In a nutshell, :class:`ReplaceValidate` grants access to :meth:`fgraph.replace_validate`,
and :meth:`fgraph.replace_validate` allows us to replace a :class:`Variable` with
and :meth:`fgraph.replace_validate` allows us to replace a :class:`Variable` with
another while respecting certain validation constraints. As an
another while respecting certain validation constraints. As an
exercise, try to rewrite :class:`Simplify` using :class:`
NodeFind
er`. (Hint: you
exercise, try to rewrite :class:`Simplify` using :class:`
WalkingGraphRewrit
er`. (Hint: you
want to use the method it publishes instead of the call to toposort)
want to use the method it publishes instead of the call to toposort)
Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by
Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by
...
...
doc/library/graph/features.rst
浏览文件 @
2d46d60e
...
@@ -26,7 +26,3 @@ Guide
...
@@ -26,7 +26,3 @@ Guide
.. class:: ReplaceValidate(History, Validator)
.. class:: ReplaceValidate(History, Validator)
.. method:: replace_validate(fgraph, var, new_var, reason=None)
.. method:: replace_validate(fgraph, var, new_var, reason=None)
.. class:: NodeFinder(Bookkeeper)
.. class:: PrintListener(object)
pytensor/graph/features.py
浏览文件 @
2d46d60e
...
@@ -827,100 +827,6 @@ class ReplaceValidate(History, Validator):
...
@@ -827,100 +827,6 @@ class ReplaceValidate(History, Validator):
raise
InconsistencyError
(
"Trying to reintroduce a removed node"
)
raise
InconsistencyError
(
"Trying to reintroduce a removed node"
)
class
NodeFinder
(
Bookkeeper
):
def
__init__
(
self
):
self
.
fgraph
=
None
self
.
d
=
{}
def
on_attach
(
self
,
fgraph
):
if
hasattr
(
fgraph
,
"get_nodes"
):
raise
AlreadyThere
(
"NodeFinder is already present"
)
if
self
.
fgraph
is
not
None
and
self
.
fgraph
!=
fgraph
:
raise
Exception
(
"A NodeFinder instance can only serve one FunctionGraph."
)
self
.
fgraph
=
fgraph
fgraph
.
get_nodes
=
partial
(
self
.
query
,
fgraph
)
Bookkeeper
.
on_attach
(
self
,
fgraph
)
def
clone
(
self
):
return
type
(
self
)()
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if
self
.
fgraph
is
not
fgraph
:
raise
Exception
(
"This NodeFinder instance was not attached to the provided fgraph."
)
self
.
fgraph
=
None
del
fgraph
.
get_nodes
Bookkeeper
.
on_detach
(
self
,
fgraph
)
def
on_import
(
self
,
fgraph
,
node
,
reason
):
try
:
self
.
d
.
setdefault
(
node
.
op
,
[])
.
append
(
node
)
except
TypeError
:
# node.op is unhashable
return
except
Exception
as
e
:
print
(
"OFFENDING node"
,
type
(
node
),
type
(
node
.
op
),
file
=
sys
.
stderr
)
# noqa: T201
try
:
print
(
"OFFENDING node hash"
,
hash
(
node
.
op
),
file
=
sys
.
stderr
)
# noqa: T201
except
Exception
:
print
(
"OFFENDING node not hashable"
,
file
=
sys
.
stderr
)
# noqa: T201
raise
e
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
try
:
nodes
=
self
.
d
[
node
.
op
]
except
TypeError
:
# node.op is unhashable
return
nodes
.
remove
(
node
)
if
not
nodes
:
del
self
.
d
[
node
.
op
]
def
query
(
self
,
fgraph
,
op
):
try
:
all
=
self
.
d
.
get
(
op
,
[])
except
TypeError
:
raise
TypeError
(
f
"{op} in unhashable and cannot be queried by the optimizer"
)
all
=
list
(
all
)
return
all
class
PrintListener
(
Feature
):
def
__init__
(
self
,
active
=
True
):
self
.
active
=
active
def
on_attach
(
self
,
fgraph
):
if
self
.
active
:
print
(
"-- attaching to: "
,
fgraph
)
# noqa: T201
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if
self
.
active
:
print
(
"-- detaching from: "
,
fgraph
)
# noqa: T201
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
print
(
f
"-- importing: {node}, reason: {reason}"
)
# noqa: T201
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
print
(
f
"-- pruning: {node}, reason: {reason}"
)
# noqa: T201
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
if
self
.
active
:
print
(
f
"-- changing ({node}.inputs[{i}]) from {r} to {new_r}"
)
# noqa: T201
class
PreserveVariableAttributes
(
Feature
):
class
PreserveVariableAttributes
(
Feature
):
"""
"""
This preserve some variables attributes and tag during optimization.
This preserve some variables attributes and tag during optimization.
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
2d46d60e
差异被折叠。
点击展开。
pytensor/tensor/rewriting/basic.py
浏览文件 @
2d46d60e
...
@@ -34,7 +34,6 @@ from pytensor.graph.basic import Constant
...
@@ -34,7 +34,6 @@ from pytensor.graph.basic import Constant
from
pytensor.graph.rewriting.basic
import
(
from
pytensor.graph.rewriting.basic
import
(
NodeProcessingGraphRewriter
,
NodeProcessingGraphRewriter
,
NodeRewriter
,
NodeRewriter
,
RemovalNodeRewriter
,
Rewriter
,
Rewriter
,
copy_stack_trace
,
copy_stack_trace
,
in2out
,
in2out
,
...
@@ -1224,7 +1223,10 @@ def local_merge_alloc(fgraph, node):
...
@@ -1224,7 +1223,10 @@ def local_merge_alloc(fgraph, node):
return
[
alloc
(
inputs_inner
[
0
],
*
dims_outer
)]
return
[
alloc
(
inputs_inner
[
0
],
*
dims_outer
)]
register_canonicalize
(
RemovalNodeRewriter
(
tensor_copy
),
name
=
"remove_tensor_copy"
)
@register_canonicalize
@node_rewriter
(
tracks
=
[
tensor_copy
])
def
remove_tensor_copy
(
fgraph
,
node
):
return
node
.
inputs
@register_specialize
@register_specialize
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
2d46d60e
...
@@ -3162,13 +3162,6 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
...
@@ -3162,13 +3162,6 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
return
np
.
allclose
(
x
,
ref
,
rtol
=
rtol
,
atol
=
atol
)
return
np
.
allclose
(
x
,
ref
,
rtol
=
rtol
,
atol
=
atol
)
def
_skip_mul_1
(
r
):
if
r
.
owner
and
r
.
owner
.
op
==
mul
:
not_is_1
=
[
i
for
i
in
r
.
owner
.
inputs
if
not
_is_1
(
i
)]
if
len
(
not_is_1
)
==
1
:
return
not_is_1
[
0
]
def
_is_1
(
expr
):
def
_is_1
(
expr
):
"""
"""
...
@@ -3190,7 +3183,6 @@ logsigm_to_softplus = PatternNodeRewriter(
...
@@ -3190,7 +3183,6 @@ logsigm_to_softplus = PatternNodeRewriter(
(
neg
,
(
softplus
,
(
neg
,
"x"
))),
(
neg
,
(
softplus
,
(
neg
,
"x"
))),
allow_multiple_clients
=
True
,
allow_multiple_clients
=
True
,
values_eq_approx
=
values_eq_approx_remove_inf
,
values_eq_approx
=
values_eq_approx_remove_inf
,
skip_identities_fn
=
_skip_mul_1
,
tracks
=
[
sigmoid
],
tracks
=
[
sigmoid
],
get_nodes
=
get_clients_at_depth1
,
get_nodes
=
get_clients_at_depth1
,
)
)
...
@@ -3199,7 +3191,6 @@ log1msigm_to_softplus = PatternNodeRewriter(
...
@@ -3199,7 +3191,6 @@ log1msigm_to_softplus = PatternNodeRewriter(
(
neg
,
(
softplus
,
"x"
)),
(
neg
,
(
softplus
,
"x"
)),
allow_multiple_clients
=
True
,
allow_multiple_clients
=
True
,
values_eq_approx
=
values_eq_approx_remove_inf
,
values_eq_approx
=
values_eq_approx_remove_inf
,
skip_identities_fn
=
_skip_mul_1
,
tracks
=
[
sigmoid
],
tracks
=
[
sigmoid
],
get_nodes
=
get_clients_at_depth2
,
get_nodes
=
get_clients_at_depth2
,
)
)
...
...
tests/compile/function/test_types.py
浏览文件 @
2d46d60e
...
@@ -13,7 +13,7 @@ from pytensor.compile.io import In, Out
...
@@ -13,7 +13,7 @@ from pytensor.compile.io import In, Out
from
pytensor.compile.mode
import
Mode
,
get_default_mode
from
pytensor.compile.mode
import
Mode
,
get_default_mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.rewriting.basic
import
OpKeyGraphRewriter
,
PatternNode
Rewriter
from
pytensor.graph.rewriting.basic
import
PatternNodeRewriter
,
WalkingGraph
Rewriter
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.link.vm
import
VMLinker
from
pytensor.link.vm
import
VMLinker
from
pytensor.printing
import
debugprint
from
pytensor.printing
import
debugprint
...
@@ -39,7 +39,7 @@ pytestmark = pytest.mark.filterwarnings("error")
...
@@ -39,7 +39,7 @@ pytestmark = pytest.mark.filterwarnings("error")
def
PatternOptimizer
(
p1
,
p2
,
ign
=
True
):
def
PatternOptimizer
(
p1
,
p2
,
ign
=
True
):
return
OpKey
GraphRewriter
(
PatternNodeRewriter
(
p1
,
p2
),
ignore_newtrees
=
ign
)
return
Walking
GraphRewriter
(
PatternNodeRewriter
(
p1
,
p2
),
ignore_newtrees
=
ign
)
class
TestFunction
:
class
TestFunction
:
...
...
tests/graph/rewriting/test_basic.py
浏览文件 @
2d46d60e
差异被折叠。
点击展开。
tests/graph/test_destroyhandler.py
浏览文件 @
2d46d60e
...
@@ -10,7 +10,6 @@ from pytensor.graph.fg import FunctionGraph
...
@@ -10,7 +10,6 @@ from pytensor.graph.fg import FunctionGraph
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.rewriting.basic
import
(
from
pytensor.graph.rewriting.basic
import
(
NodeProcessingGraphRewriter
,
NodeProcessingGraphRewriter
,
OpKeyGraphRewriter
,
PatternNodeRewriter
,
PatternNodeRewriter
,
SubstitutionNodeRewriter
,
SubstitutionNodeRewriter
,
WalkingGraphRewriter
,
WalkingGraphRewriter
,
...
@@ -21,7 +20,7 @@ from tests.unittest_tools import assertFailure_fast
...
@@ -21,7 +20,7 @@ from tests.unittest_tools import assertFailure_fast
def
OpKeyPatternNodeRewriter
(
p1
,
p2
,
ign
=
True
):
def
OpKeyPatternNodeRewriter
(
p1
,
p2
,
ign
=
True
):
return
OpKey
GraphRewriter
(
PatternNodeRewriter
(
p1
,
p2
),
ignore_newtrees
=
ign
)
return
Walking
GraphRewriter
(
PatternNodeRewriter
(
p1
,
p2
),
ignore_newtrees
=
ign
)
def
TopoSubstitutionNodeRewriter
(
def
TopoSubstitutionNodeRewriter
(
...
...
tests/graph/test_features.py
浏览文件 @
2d46d60e
...
@@ -2,92 +2,12 @@ import pytest
...
@@ -2,92 +2,12 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.graph
import
rewrite_graph
from
pytensor.graph
import
rewrite_graph
from
pytensor.graph.basic
import
Apply
,
Variable
,
equal_computations
from
pytensor.graph.basic
import
equal_computations
from
pytensor.graph.features
import
Feature
,
FullHistory
,
NodeFinder
,
ReplaceValidate
from
pytensor.graph.features
import
Feature
,
FullHistory
,
ReplaceValidate
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
Op
from
pytensor.graph.type
import
Type
from
tests.graph.utils
import
MyVariable
,
op1
from
tests.graph.utils
import
MyVariable
,
op1
class
TestNodeFinder
:
def
test_straightforward
(
self
):
class
MyType
(
Type
):
def
__init__
(
self
,
name
):
self
.
name
=
name
def
filter
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
()
def
__str__
(
self
):
return
self
.
name
def
__repr__
(
self
):
return
self
.
name
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
MyType
)
class
MyOp
(
Op
):
__props__
=
(
"nin"
,
"name"
)
def
__init__
(
self
,
nin
,
name
):
self
.
nin
=
nin
self
.
name
=
name
def
make_node
(
self
,
*
inputs
):
def
as_variable
(
x
):
assert
isinstance
(
x
,
Variable
)
return
x
assert
len
(
inputs
)
==
self
.
nin
inputs
=
list
(
map
(
as_variable
,
inputs
))
for
input
in
inputs
:
if
not
isinstance
(
input
.
type
,
MyType
):
raise
Exception
(
"Error 1"
)
outputs
=
[
MyType
(
self
.
name
+
"_R"
)()]
return
Apply
(
self
,
inputs
,
outputs
)
def
__str__
(
self
):
return
self
.
name
def
perform
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
()
sigmoid
=
MyOp
(
1
,
"Sigmoid"
)
add
=
MyOp
(
2
,
"Add"
)
dot
=
MyOp
(
2
,
"Dot"
)
def
MyVariable
(
name
):
return
Variable
(
MyType
(
name
),
None
,
None
)
def
inputs
():
x
=
MyVariable
(
"x"
)
y
=
MyVariable
(
"y"
)
z
=
MyVariable
(
"z"
)
return
x
,
y
,
z
x
,
y
,
z
=
inputs
()
e0
=
dot
(
y
,
z
)
e
=
add
(
add
(
sigmoid
(
x
),
sigmoid
(
sigmoid
(
z
))),
dot
(
add
(
x
,
y
),
e0
))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
],
clone
=
False
)
g
.
attach_feature
(
NodeFinder
())
assert
hasattr
(
g
,
"get_nodes"
)
for
type
,
num
in
((
add
,
3
),
(
sigmoid
,
3
),
(
dot
,
2
)):
if
len
(
list
(
g
.
get_nodes
(
type
)))
!=
num
:
raise
Exception
(
f
"Expected: {num} times {type}"
)
new_e0
=
add
(
y
,
z
)
assert
e0
.
owner
in
g
.
get_nodes
(
dot
)
assert
new_e0
.
owner
not
in
g
.
get_nodes
(
add
)
g
.
replace
(
e0
,
new_e0
)
assert
e0
.
owner
not
in
g
.
get_nodes
(
dot
)
assert
new_e0
.
owner
in
g
.
get_nodes
(
add
)
for
type
,
num
in
((
add
,
4
),
(
sigmoid
,
3
),
(
dot
,
1
)):
if
len
(
list
(
g
.
get_nodes
(
type
)))
!=
num
:
raise
Exception
(
f
"Expected: {num} times {type}"
)
class
TestReplaceValidate
:
class
TestReplaceValidate
:
def
test_verbose
(
self
,
capsys
):
def
test_verbose
(
self
,
capsys
):
var1
=
MyVariable
(
"var1"
)
var1
=
MyVariable
(
"var1"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论