Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
857cda05
提交
857cda05
authored
7月 07, 2024
作者:
Virgile Andreani
提交者:
Virgile Andreani
7月 09, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor graph/rewriting/utils.py
上级
0484e1e2
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
56 行增加
和
83 行删除
+56
-83
utils.py
pytensor/graph/rewriting/utils.py
+56
-83
没有找到文件。
pytensor/graph/rewriting/utils.py
浏览文件 @
857cda05
...
...
@@ -44,32 +44,23 @@ def rewrite_graph(
"""
from
pytensor.compile
import
optdb
return_fgraph
=
False
if
isinstance
(
graph
,
FunctionGraph
):
fgraph
=
graph
return_fgraph
=
True
else
:
if
isinstance
(
graph
,
list
|
tuple
):
outputs
=
graph
else
:
assert
isinstance
(
graph
,
Variable
)
outputs
=
[
graph
]
outputs
=
[
graph
]
if
isinstance
(
graph
,
Variable
)
else
graph
fgraph
=
FunctionGraph
(
outputs
=
outputs
,
clone
=
clone
)
query_rewrites
=
optdb
.
query
(
RewriteDatabaseQuery
(
include
=
include
,
**
kwargs
))
_
=
query_rewrites
.
rewrite
(
fgraph
)
query_rewrites
.
rewrite
(
fgraph
)
if
custom_rewrite
:
if
custom_rewrite
is
not
None
:
custom_rewrite
.
rewrite
(
fgraph
)
if
return_fgraph
:
if
isinstance
(
graph
,
FunctionGraph
)
:
return
fgraph
else
:
if
isinstance
(
graph
,
list
|
tuple
):
return
fgraph
.
outputs
else
:
return
fgraph
.
outputs
[
0
]
if
isinstance
(
graph
,
Variable
):
return
fgraph
.
outputs
[
0
]
return
fgraph
.
outputs
def
is_same_graph_with_merge
(
...
...
@@ -90,14 +81,10 @@ def is_same_graph_with_merge(
"""
from
pytensor.graph.rewriting.basic
import
MergeOptimizer
if
givens
is
None
:
givens
=
{}
givens
=
dict
(
givens
)
givens
=
{}
if
givens
is
None
else
dict
(
givens
)
# Copy variables since the MergeOptimizer will modify them.
copied
=
copy
.
deepcopy
((
var1
,
var2
,
givens
))
vars
=
copied
[
0
:
2
]
givens
=
copied
[
2
]
*
vars
,
givens
=
copy
.
deepcopy
((
var1
,
var2
,
givens
))
# Create FunctionGraph.
inputs
=
list
(
graph_inputs
(
vars
))
# The clone isn't needed as we did a deepcopy and we cloning will
...
...
@@ -120,8 +107,7 @@ def is_same_graph_with_merge(
# Comparing two single-Variable graphs: they are equal if they are
# the same Variable.
return
vars_replaced
[
0
]
==
vars_replaced
[
1
]
else
:
return
o1
is
o2
return
o1
is
o2
def
is_same_graph
(
...
...
@@ -171,71 +157,58 @@ def is_same_graph(
====== ====== ====== ======
"""
use_equal_computations
=
True
if
givens
is
None
:
givens
=
{}
givens
=
dict
(
givens
)
givens
=
{}
if
givens
is
None
else
dict
(
givens
)
# Get result from the merge-based function.
rval1
=
is_same_graph_with_merge
(
var1
=
var1
,
var2
=
var2
,
givens
=
givens
)
if
givens
:
# We need to build the `in_xs` and `in_ys` lists. To do this, we need
# to be able to tell whether a variable belongs to the computational
# graph of `var1` or `var2`.
# The typical case we want to handle is when `to_replace` belongs to
# one of these graphs, and `replace_by` belongs to the other one. In
# other situations, the current implementation of `equal_computations`
# is probably not appropriate, so we do not call it.
ok
=
True
in_xs
=
[]
in_ys
=
[]
# Compute the sets of all variables found in each computational graph.
inputs_var1
=
graph_inputs
([
var1
])
inputs_var2
=
graph_inputs
([
var2
])
all_vars
=
[
set
(
vars_between
(
v_i
,
v_o
))
for
v_i
,
v_o
in
((
inputs_var1
,
[
var1
]),
(
inputs_var2
,
[
var2
]))
]
def
in_var
(
x
,
k
):
# Return True iff `x` is in computation graph of variable `vark`.
return
x
in
all_vars
[
k
-
1
]
if
not
givens
:
rval2
=
equal_computations
(
xs
=
[
var1
],
ys
=
[
var2
])
assert
rval1
==
rval2
return
rval1
# We need to build the `in_xs` and `in_ys` lists. To do this, we need
# to be able to tell whether a variable belongs to the computational
# graph of `var1` or `var2`.
# The typical case we want to handle is when `to_replace` belongs to
# one of these graphs, and `replace_by` belongs to the other one. In
# other situations, the current implementation of `equal_computations`
# is probably not appropriate, so we do not call it.
use_equal_computations
=
True
in_xs
=
[]
in_ys
=
[]
# Compute the sets of all variables found in each computational graph.
inputs_var1
=
graph_inputs
([
var1
])
inputs_var2
=
graph_inputs
([
var2
])
all_vars1
=
set
(
vars_between
(
inputs_var1
,
[
var1
]))
all_vars2
=
set
(
vars_between
(
inputs_var2
,
[
var2
]))
for
to_replace
,
replace_by
in
givens
.
items
():
# Map a substitution variable to the computational graphs it
# belongs to.
inside
=
{
v
:
[
in_var
(
v
,
k
)
for
k
in
(
1
,
2
)]
for
v
in
(
to_replace
,
replace_by
)
}
if
(
inside
[
to_replace
][
0
]
and
not
inside
[
to_replace
][
1
]
and
inside
[
replace_by
][
1
]
and
not
inside
[
replace_by
][
0
]
):
# Substitute variable in `var1` by one from `var2`.
in_xs
.
append
(
to_replace
)
in_ys
.
append
(
replace_by
)
elif
(
inside
[
to_replace
][
1
]
and
not
inside
[
to_replace
][
0
]
and
inside
[
replace_by
][
0
]
and
not
inside
[
replace_by
][
1
]
):
# Substitute variable in `var2` by one from `var1`.
in_xs
.
append
(
replace_by
)
in_ys
.
append
(
to_replace
)
else
:
ok
=
False
break
if
not
ok
:
# We cannot directly use `equal_computations`.
for
to_replace
,
replace_by
in
givens
.
items
():
# Map a substitution variable to the computational graphs it
# belongs to.
inside
=
{
v
:
[
v
in
all_vars1
,
v
in
all_vars2
]
for
v
in
(
to_replace
,
replace_by
)}
if
(
inside
[
to_replace
][
0
]
and
not
inside
[
to_replace
][
1
]
and
inside
[
replace_by
][
1
]
and
not
inside
[
replace_by
][
0
]
):
# Substitute variable in `var1` by one from `var2`.
in_xs
.
append
(
to_replace
)
in_ys
.
append
(
replace_by
)
elif
(
inside
[
to_replace
][
1
]
and
not
inside
[
to_replace
][
0
]
and
inside
[
replace_by
][
0
]
and
not
inside
[
replace_by
][
1
]
):
# Substitute variable in `var2` by one from `var1`.
in_xs
.
append
(
replace_by
)
in_ys
.
append
(
to_replace
)
else
:
use_equal_computations
=
False
else
:
in_xs
=
None
in_ys
=
None
break
if
use_equal_computations
:
rval2
=
equal_computations
(
xs
=
[
var1
],
ys
=
[
var2
],
in_xs
=
in_xs
,
in_ys
=
in_ys
)
assert
rval2
==
rval1
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论