Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e9a7d7ce
Unverified
提交
e9a7d7ce
authored
7月 26, 2023
作者:
Maxim Kochurov
提交者:
GitHub
7月 26, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Speedup `truncated_graph_inputs` (#394)
* add pytest-mock dependency * rename to node to variable
上级
673c1acc
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
57 行增加
和
38 行删除
+57
-38
test.yml
.github/workflows/test.yml
+1
-1
environment.yml
environment.yml
+1
-0
pyproject.toml
pyproject.toml
+1
-0
basic.py
pytensor/graph/basic.py
+39
-37
test_basic.py
tests/graph/test_basic.py
+15
-0
没有找到文件。
.github/workflows/test.yml
浏览文件 @
e9a7d7ce
...
...
@@ -139,7 +139,7 @@ jobs:
-
name
:
Install dependencies
shell
:
bash -l {0}
run
:
|
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark
pytest-mock
sympy
# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but
# not numpy, even though scipy 1.7 requires numpy<1.23. When installing
# PyTensor next, pip installs a lower version of numpy via the PyPI.
...
...
environment.yml
浏览文件 @
e9a7d7ce
...
...
@@ -31,6 +31,7 @@ dependencies:
-
pytest-cov
-
pytest-xdist
-
pytest-benchmark
-
pytest-mock
# For building docs
-
sphinx>=5.1.0,<6
-
sphinx_rtd_theme
...
...
pyproject.toml
浏览文件 @
e9a7d7ce
...
...
@@ -86,6 +86,7 @@ tests = [
"pytest-cov>=2.6.1"
,
"coverage>=5.1"
,
"pytest-benchmark"
,
"pytest-mock"
,
]
rtd
=
[
"sphinx>=5.1.0,<6"
,
...
...
pytensor/graph/basic.py
浏览文件 @
e9a7d7ce
...
...
@@ -1003,14 +1003,14 @@ def applys_between(
def
truncated_graph_inputs
(
outputs
:
Sequence
[
Variable
],
ancestors_to_include
:
Optional
[
Collection
[
Variable
]]
=
None
,
)
->
L
ist
[
Variable
]:
)
->
l
ist
[
Variable
]:
"""Get the truncate graph inputs.
Unlike :func:`graph_inputs` this function will return
the closest
nod
es to outputs that do not depend on
the closest
variabl
es to outputs that do not depend on
``ancestors_to_include``. So given all the returned
variables provided there is no missing
nod
e to
compute the output and all
nod
es are independent
variables provided there is no missing
variabl
e to
compute the output and all
variabl
es are independent
from each other.
Parameters
...
...
@@ -1027,7 +1027,7 @@ def truncated_graph_inputs(
Examples
--------
The returned
nodes marked in (parenthesis), ancestors nodes are ``c``, output nod
es are ``o``
The returned
variables marked in (parenthesis), ancestors variables are ``c``, output variabl
es are ``o``
* No ancestors to include
...
...
@@ -1047,7 +1047,7 @@ def truncated_graph_inputs(
(c) - (c) - o
* Additional
nod
es are present
* Additional
variabl
es are present
.. code-block::
...
...
@@ -1076,58 +1076,60 @@ def truncated_graph_inputs(
"""
# simple case, no additional ancestors to include
truncated_inputs
=
list
()
# blockers have known independent
nod
es and ancestors to include
truncated_inputs
:
list
[
Variable
]
=
list
()
# blockers have known independent
variabl
es and ancestors to include
candidates
=
list
(
outputs
)
if
not
ancestors_to_include
:
# None or empty
# just filter out unique variables
for
nod
e
in
candidates
:
if
nod
e
not
in
truncated_inputs
:
truncated_inputs
.
append
(
nod
e
)
for
variabl
e
in
candidates
:
if
variabl
e
not
in
truncated_inputs
:
truncated_inputs
.
append
(
variabl
e
)
# no more actions are needed
return
truncated_inputs
blockers
:
Set
[
Variable
]
=
set
(
ancestors_to_include
)
# enforce O(1) check for node in ancestors to include
blockers
:
set
[
Variable
]
=
set
(
ancestors_to_include
)
# variables that go here are under check already, do not repeat the loop for them
seen
:
set
[
Variable
]
=
set
()
# enforce O(1) check for variable in ancestors to include
ancestors_to_include
=
blockers
.
copy
()
while
candidates
:
# on any new candidate
node
=
candidates
.
pop
()
# There was a repeated reference to this node, we have already investigated it
if
node
in
truncated_inputs
:
variable
=
candidates
.
pop
()
# we've looked into this variable already
if
variable
in
seen
:
continue
# check if the node is independent, never go above blockers;
# blockers are independent nodes and ancestors to include
if
node
in
ancestors_to_include
:
# The case where node is in ancestors to include so we check if it depends on others
# check if the variable is independent, never go above blockers;
# blockers are independent variables and ancestors to include
elif
variable
in
ancestors_to_include
:
# The case where variable is in ancestors to include so we check if it depends on others
# it should be removed from the blockers to check against the rest
dependent
=
variable_depends_on
(
node
,
ancestors_to_include
-
{
nod
e
})
dependent
=
variable_depends_on
(
variable
,
ancestors_to_include
-
{
variabl
e
})
# ancestors to include that are present in the graph (not disconnected)
# should be added to truncated_inputs
truncated_inputs
.
append
(
nod
e
)
truncated_inputs
.
append
(
variabl
e
)
if
dependent
:
# if the ancestors to include is still dependent we need to go above, the search is not yet finished
# owner can never be None for a dependent
nod
e
candidates
.
extend
(
n
ode
.
owner
.
inputs
)
# owner can never be None for a dependent
variabl
e
candidates
.
extend
(
n
for
n
in
variable
.
owner
.
inputs
if
n
not
in
seen
)
else
:
# A regular
nod
e to check
dependent
=
variable_depends_on
(
nod
e
,
blockers
)
# all regular
nod
es fall to blockers
# A regular
variabl
e to check
dependent
=
variable_depends_on
(
variabl
e
,
blockers
)
# all regular
variabl
es fall to blockers
# 1. it is dependent - further search irrelevant
# 2. it is independent - the search
nod
e is inside the closure
blockers
.
add
(
nod
e
)
# if we've found an independent
nod
e and it is not in blockers so far
# it is a new independent
nod
e not present in ancestors to include
# 2. it is independent - the search
variabl
e is inside the closure
blockers
.
add
(
variabl
e
)
# if we've found an independent
variabl
e and it is not in blockers so far
# it is a new independent
variabl
e not present in ancestors to include
if
dependent
:
# populate search if it's not an independent
nod
e
# owner can never be None for a dependent
nod
e
candidates
.
extend
(
n
ode
.
owner
.
inputs
)
# populate search if it's not an independent
variabl
e
# owner can never be None for a dependent
variabl
e
candidates
.
extend
(
n
for
n
in
variable
.
owner
.
inputs
if
n
not
in
seen
)
else
:
# otherwise, do not search beyond
truncated_inputs
.
append
(
node
)
truncated_inputs
.
append
(
variable
)
# add variable to seen, no point in checking it once more
seen
.
add
(
variable
)
return
truncated_inputs
...
...
tests/graph/test_basic.py
浏览文件 @
e9a7d7ce
...
...
@@ -795,3 +795,18 @@ class TestTruncatedGraphInputs:
o2
.
name
=
"o2"
assert
truncated_graph_inputs
([
o2
],
[
trunc_inp
])
==
[
trunc_inp
,
x
]
def
test_single_pass_per_node
(
self
,
mocker
):
import
pytensor.graph.basic
inspect
=
mocker
.
spy
(
pytensor
.
graph
.
basic
,
"variable_depends_on"
)
x
=
at
.
dmatrix
(
"x"
)
m
=
x
.
shape
[
0
][
None
,
None
]
f
=
x
/
m
w
=
x
/
m
-
f
truncated_graph_inputs
([
w
],
[
x
])
# make sure there were exactly the same calls as unique variables seen by the function
assert
len
(
inspect
.
call_args_list
)
==
len
(
{
a
for
((
a
,
b
),
kw
)
in
inspect
.
call_args_list
}
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论