Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8ad33179
提交
8ad33179
authored
12月 09, 2022
作者:
Maxim Kochurov
提交者:
Ricardo Vieira
12月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor is_in_ancestors to support multiple inputs
上级
38731adb
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
27 行增加
和
19 行删除
+27
-19
basic.py
pytensor/graph/basic.py
+12
-8
ifelse.py
pytensor/ifelse.py
+4
-4
rewriting.py
pytensor/scan/rewriting.py
+3
-3
test_basic.py
tests/graph/test_basic.py
+8
-4
没有找到文件。
pytensor/graph/basic.py
浏览文件 @
8ad33179
...
@@ -1568,15 +1568,15 @@ def list_of_nodes(
...
@@ -1568,15 +1568,15 @@ def list_of_nodes(
)
)
def
is_in_ancestors
(
l_apply
:
Apply
,
f_apply
:
Apply
)
->
bool
:
def
apply_depends_on
(
apply
:
Apply
,
depends_on
:
Union
[
Apply
,
Collection
[
Apply
]]
)
->
bool
:
"""Determine if
`f_apply` is in the graph given by `l_apply
`.
"""Determine if
any `depends_on` is in the graph given by ``apply`
`.
Parameters
Parameters
----------
----------
l_
apply : Apply
apply : Apply
The
node to wal
k.
The
Apply node to chec
k.
f_apply : Apply
depends_on : Union[Apply, Collection[Apply]]
The node to find in `l_apply`.
Apply nodes to check dependency on
Returns
Returns
-------
-------
...
@@ -1584,14 +1584,18 @@ def is_in_ancestors(l_apply: Apply, f_apply: Apply) -> bool:
...
@@ -1584,14 +1584,18 @@ def is_in_ancestors(l_apply: Apply, f_apply: Apply) -> bool:
"""
"""
computed
=
set
()
computed
=
set
()
todo
=
[
l_apply
]
todo
=
[
apply
]
if
not
isinstance
(
depends_on
,
Collection
):
depends_on
=
{
depends_on
}
else
:
depends_on
=
set
(
depends_on
)
while
todo
:
while
todo
:
cur
=
todo
.
pop
()
cur
=
todo
.
pop
()
if
cur
.
outputs
[
0
]
in
computed
:
if
cur
.
outputs
[
0
]
in
computed
:
continue
continue
if
all
(
i
in
computed
or
i
.
owner
is
None
for
i
in
cur
.
inputs
):
if
all
(
i
in
computed
or
i
.
owner
is
None
for
i
in
cur
.
inputs
):
computed
.
update
(
cur
.
outputs
)
computed
.
update
(
cur
.
outputs
)
if
cur
i
s
f_apply
:
if
cur
i
n
depends_on
:
return
True
return
True
else
:
else
:
todo
.
append
(
cur
)
todo
.
append
(
cur
)
...
...
pytensor/ifelse.py
浏览文件 @
8ad33179
...
@@ -20,7 +20,7 @@ import pytensor.tensor as at
...
@@ -20,7 +20,7 @@ import pytensor.tensor as at
from
pytensor
import
as_symbolic
from
pytensor
import
as_symbolic
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
Variable
,
is_in_ancestors
from
pytensor.graph.basic
import
Apply
,
Variable
,
apply_depends_on
from
pytensor.graph.op
import
_NoPythonOp
from
pytensor.graph.op
import
_NoPythonOp
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.rewriting.basic
import
GraphRewriter
,
in2out
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
GraphRewriter
,
in2out
,
node_rewriter
...
@@ -604,7 +604,7 @@ class CondMerge(GraphRewriter):
...
@@ -604,7 +604,7 @@ class CondMerge(GraphRewriter):
return
False
return
False
merging_node
=
cond_nodes
[
0
]
merging_node
=
cond_nodes
[
0
]
for
proposal
in
cond_nodes
[
1
:]:
for
proposal
in
cond_nodes
[
1
:]:
if
proposal
.
inputs
[
0
]
==
merging_node
.
inputs
[
0
]
and
not
is_in_ancestors
(
if
proposal
.
inputs
[
0
]
==
merging_node
.
inputs
[
0
]
and
not
apply_depends_on
(
proposal
,
merging_node
proposal
,
merging_node
):
):
# Create a list of replacements for proposal
# Create a list of replacements for proposal
...
@@ -704,8 +704,8 @@ def cond_merge_random_op(fgraph, main_node):
...
@@ -704,8 +704,8 @@ def cond_merge_random_op(fgraph, main_node):
for
proposal
in
cond_nodes
[
1
:]:
for
proposal
in
cond_nodes
[
1
:]:
if
(
if
(
proposal
.
inputs
[
0
]
==
merging_node
.
inputs
[
0
]
proposal
.
inputs
[
0
]
==
merging_node
.
inputs
[
0
]
and
not
is_in_ancestors
(
proposal
,
merging_node
)
and
not
apply_depends_on
(
proposal
,
merging_node
)
and
not
is_in_ancestors
(
merging_node
,
proposal
)
and
not
apply_depends_on
(
merging_node
,
proposal
)
):
):
# Create a list of replacements for proposal
# Create a list of replacements for proposal
mn_ts
=
merging_node
.
inputs
[
1
:][:
merging_node
.
op
.
n_outs
]
mn_ts
=
merging_node
.
inputs
[
1
:][:
merging_node
.
op
.
n_outs
]
...
...
pytensor/scan/rewriting.py
浏览文件 @
8ad33179
...
@@ -18,10 +18,10 @@ from pytensor.graph.basic import (
...
@@ -18,10 +18,10 @@ from pytensor.graph.basic import (
Apply
,
Apply
,
Constant
,
Constant
,
Variable
,
Variable
,
apply_depends_on
,
equal_computations
,
equal_computations
,
graph_inputs
,
graph_inputs
,
io_toposort
,
io_toposort
,
is_in_ancestors
,
)
)
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.features
import
ReplaceValidate
...
@@ -1642,7 +1642,7 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1642,7 +1642,7 @@ def save_mem_new_scan(fgraph, node):
old_new
+=
[(
o
,
new_outs
[
nw_pos
])]
old_new
+=
[(
o
,
new_outs
[
nw_pos
])]
# Check if the new outputs depend on the old scan node
# Check if the new outputs depend on the old scan node
old_scan_is_used
=
[
old_scan_is_used
=
[
is_in_ancestors
(
new
.
owner
,
node
)
for
old
,
new
in
old_new
apply_depends_on
(
new
.
owner
,
node
)
for
old
,
new
in
old_new
]
]
if
any
(
old_scan_is_used
):
if
any
(
old_scan_is_used
):
return
False
return
False
...
@@ -1877,7 +1877,7 @@ class ScanMerge(GraphRewriter):
...
@@ -1877,7 +1877,7 @@ class ScanMerge(GraphRewriter):
# Check to see if it is an input of a different node
# Check to see if it is an input of a different node
for
nd
in
set_nodes
:
for
nd
in
set_nodes
:
if
is_in_ancestors
(
node
,
nd
)
or
is_in_ancestors
(
nd
,
node
):
if
apply_depends_on
(
node
,
nd
)
or
apply_depends_on
(
nd
,
node
):
return
False
return
False
if
not
node
.
op
.
info
.
as_while
:
if
not
node
.
op
.
info
.
as_while
:
...
...
tests/graph/test_basic.py
浏览文件 @
8ad33179
...
@@ -11,6 +11,7 @@ from pytensor.graph.basic import (
...
@@ -11,6 +11,7 @@ from pytensor.graph.basic import (
NominalVariable
,
NominalVariable
,
Variable
,
Variable
,
ancestors
,
ancestors
,
apply_depends_on
,
applys_between
,
applys_between
,
as_string
,
as_string
,
clone
,
clone
,
...
@@ -20,7 +21,6 @@ from pytensor.graph.basic import (
...
@@ -20,7 +21,6 @@ from pytensor.graph.basic import (
get_var_by_name
,
get_var_by_name
,
graph_inputs
,
graph_inputs
,
io_toposort
,
io_toposort
,
is_in_ancestors
,
list_of_nodes
,
list_of_nodes
,
orphans_between
,
orphans_between
,
vars_between
,
vars_between
,
...
@@ -491,15 +491,19 @@ def test_list_of_nodes():
...
@@ -491,15 +491,19 @@ def test_list_of_nodes():
assert
res
==
[
o2
.
owner
,
o1
.
owner
]
assert
res
==
[
o2
.
owner
,
o1
.
owner
]
def
test_
is_in_ancestors
():
def
test_
apply_depends_on
():
r1
,
r2
,
r3
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
)
r1
,
r2
,
r3
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
)
o1
=
MyOp
(
r1
,
r2
)
o1
=
MyOp
(
r1
,
r2
)
o1
.
name
=
"o1"
o1
.
name
=
"o1"
o2
=
MyOp
(
r
3
,
o1
)
o2
=
MyOp
(
r
1
,
o1
)
o2
.
name
=
"o2"
o2
.
name
=
"o2"
o3
=
MyOp
(
r3
,
o1
,
o2
)
o3
.
name
=
"o3"
assert
is_in_ancestors
(
o2
.
owner
,
o1
.
owner
)
assert
apply_depends_on
(
o2
.
owner
,
o1
.
owner
)
assert
apply_depends_on
(
o2
.
owner
,
o2
.
owner
)
assert
apply_depends_on
(
o3
.
owner
,
[
o1
.
owner
,
o2
.
owner
])
@pytest.mark.xfail
(
reason
=
"Not implemented"
)
@pytest.mark.xfail
(
reason
=
"Not implemented"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论