Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8a040b98
提交
8a040b98
authored
6月 28, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
6月 11, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add feature that keeps track of full rewrite history
上级
646a734d
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
197 行增加
和
2 行删除
+197
-2
features.py
pytensor/graph/features.py
+163
-0
test_features.py
tests/graph/test_features.py
+34
-2
没有找到文件。
pytensor/graph/features.py
浏览文件 @
8a040b98
...
...
@@ -438,6 +438,169 @@ class History(Feature):
self
.
history
[
fgraph
]
=
h
class FullHistory(Feature):
    """Keeps track of all changes in FunctionGraph and allows arbitrary back and forth through intermediate states

    .. testcode::

        import pytensor
        import pytensor.tensor as pt
        from pytensor.graph.fg import FunctionGraph
        from pytensor.graph.features import FullHistory
        from pytensor.graph.rewriting.utils import rewrite_graph

        x = pt.scalar("x")
        out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))

        fg = FunctionGraph(outputs=[out])
        history = FullHistory()
        fg.attach_feature(history)

        rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))

        # Replay rewrites
        history.start()
        pytensor.dprint(fg)
        with pytensor.config.change_flags(optimizer_verbose = True):
            for i in range(3):
                print(">> ", end="")
                pytensor.dprint(history.next())

    .. testoutput::

        Log [id A] 4
         └─ True_div [id B] 3
            ├─ Exp [id C] 2
            │  └─ x [id D]
            └─ Sum{axes=None} [id E] 1
               └─ Exp [id F] 0
                  └─ x [id D]
        >> MergeOptimizer
        Log [id A] 3
         └─ True_div [id B] 2
            ├─ Exp [id C] 0
            │  └─ x [id D]
            └─ Sum{axes=None} [id E] 1
               └─ Exp [id C] 0
                  └─ ···
        >> local_mul_canonizer
        Log [id A] 1
         └─ Softmax{axis=None} [id B] 0
            └─ x [id C]
        >> local_logsoftmax
        LogSoftmax{axis=None} [id A] 0
         └─ x [id B]

    .. testcode::

        # Or in reverse
        with pytensor.config.change_flags(optimizer_verbose=True):
            for i in range(3):
                print(">> ", end="")
                pytensor.dprint(history.prev())

    .. testoutput::

        >> local_logsoftmax
        Log [id A] 1
         └─ Softmax{axis=None} [id B] 0
            └─ x [id C]
        >> local_mul_canonizer
        Log [id A] 3
         └─ True_div [id B] 2
            ├─ Exp [id C] 0
            │  └─ x [id D]
            └─ Sum{axes=None} [id E] 1
               └─ Exp [id C] 0
                  └─ ···
        >> MergeOptimizer
        Log [id A] 4
         └─ True_div [id B] 3
            ├─ Exp [id C] 2
            │  └─ x [id D]
            └─ Sum{axes=None} [id E] 1
               └─ Exp [id F] 0
                  └─ x [id D]

    .. testcode::

        # Or go to any step
        pytensor.dprint(history.goto(2))

    .. testoutput::

        Log [id A] 1
         └─ Softmax{axis=None} [id B] 0
            └─ x [id C]

    """

    def __init__(self):
        # fw[k] / bw[k] apply / undo the (k+1)-th recorded change.
        self.fw = []
        self.bw = []
        # Index into fw/bw of the last change currently applied to the
        # graph; -1 means the graph is in its initial (pre-rewrite) state.
        self.pointer = -1
        self.fg = None

    def on_attach(self, fgraph):
        # Unlike `History`, this feature records a single linear timeline,
        # so it can serve only one FunctionGraph at a time.
        if self.fg is not None:
            raise ValueError("Full History already attached to another fgraph")
        self.fg = fgraph

    def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
        # Record both directions of the change: `bw` restores the old
        # variable `r`, `fw` re-applies the new variable `new_r`.
        self.bw.append(LambdaExtract(fgraph, node, i, r, reason))
        self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason))
        self.pointer += 1

    def goto(self, checkpoint):
        """
        Revert the graph to the state it had after `checkpoint` recorded
        changes had been applied (0 = initial graph, ``len(self.bw)`` =
        final graph), replaying or undoing intermediate changes as needed.

        Returns the managed FunctionGraph for convenience.
        """
        history_len = len(self.bw)
        pointer = self.pointer
        assert 0 <= checkpoint <= history_len
        verbose = config.optimizer_verbose

        # Go backwards
        while pointer > checkpoint - 1:
            reverse_fn = self.bw[pointer]
            if verbose:
                print(reverse_fn.reason)  # noqa: T201
            reverse_fn()
            pointer -= 1

        # Go forward
        while pointer < checkpoint - 1:
            pointer += 1
            forward_fn = self.fw[pointer]
            if verbose:
                print(forward_fn.reason)  # noqa: T201
            forward_fn()

        # Replaying/undoing triggers `on_change_input` again, which appends
        # spurious entries; truncate to drop the changes caused by the
        # forward/backward traversal itself.
        self.bw = self.bw[:history_len]
        self.fw = self.fw[:history_len]
        self.pointer = pointer
        return self.fg

    def start(self):
        # Rewind to the initial graph (no changes applied).
        return self.goto(0)

    def end(self):
        # Fast-forward to the final graph (all changes applied).
        return self.goto(len(self.bw))

    def prev(self):
        # Undo one change; no-op (returns the graph as-is) at the start.
        if self.pointer < 0:
            return self.fg
        else:
            # pointer is the index of the last applied change, so the
            # current checkpoint is pointer + 1; stepping back means
            # goto(pointer).
            return self.goto(self.pointer)

    def next(self):
        # Re-apply one change; no-op (returns the graph as-is) at the end.
        if self.pointer >= len(self.bw) - 1:
            return self.fg
        else:
            # Current checkpoint is pointer + 1; stepping forward means
            # goto(pointer + 2).
            return self.goto(self.pointer + 2)
class
Validator
(
Feature
):
pickle_rm_attr
=
[
"validate"
,
"consistent"
]
...
...
tests/graph/test_features.py
浏览文件 @
8a040b98
import
pytest
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.features
import
Feature
,
NodeFinder
,
ReplaceValidate
import
pytensor.tensor
as
pt
from
pytensor.graph
import
rewrite_graph
from
pytensor.graph.basic
import
Apply
,
Variable
,
equal_computations
from
pytensor.graph.features
import
Feature
,
FullHistory
,
NodeFinder
,
ReplaceValidate
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
Op
from
pytensor.graph.type
import
Type
...
...
@@ -119,3 +121,33 @@ class TestReplaceValidate:
capres
=
capsys
.
readouterr
()
assert
"rewriting: validate failed on node Op1.0"
in
capres
.
out
def test_full_history():
    # Build log(softmax(x)) in its unsimplified form and record every
    # rewrite applied while canonicalizing/stabilizing it.
    x = pt.scalar("x")
    out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
    fg = FunctionGraph(outputs=[out], clone=True, copy_inputs=False)

    history = FullHistory()
    fg.attach_feature(history)
    rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))

    # Jumping to the start recovers the original graph.
    history.start()
    assert equal_computations(fg.outputs, [out])

    # Jumping to the end recovers the fully rewritten graph.
    history.end()
    assert equal_computations(fg.outputs, [pt.special.log_softmax(x)])

    # One step back: log(softmax(x)) before the logsoftmax fusion.
    history.prev()
    assert equal_computations(fg.outputs, [pt.log(pt.special.softmax(x))])

    # Stepping back past the beginning saturates at the original graph.
    for _ in range(10):
        history.prev()
    assert equal_computations(fg.outputs, [out])

    # Arbitrary checkpoints can be targeted directly.
    history.goto(2)
    assert equal_computations(fg.outputs, [pt.log(pt.special.softmax(x))])

    # Stepping forward past the end saturates at the final graph.
    for _ in range(10):
        history.next()
    assert equal_computations(fg.outputs, [pt.special.log_softmax(x)])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论