Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e85c7fd0
提交
e85c7fd0
authored
8月 08, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
9月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace Scan info dict with ScanInfo dataclass
上级
d83ff33c
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
35 行增加
和
29 行删除
+35
-29
basic.py
aesara/scan/basic.py
+25
-24
op.py
aesara/scan/op.py
+0
-0
opt.py
aesara/scan/opt.py
+0
-0
utils.py
aesara/scan/utils.py
+0
-0
requirements.txt
requirements.txt
+1
-0
setup.py
setup.py
+8
-1
test_utils.py
tests/scan/test_utils.py
+1
-4
没有找到文件。
aesara/scan/basic.py
浏览文件 @
e85c7fd0
...
@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError
...
@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError
from
aesara.graph.op
import
get_test_value
from
aesara.graph.op
import
get_test_value
from
aesara.graph.utils
import
TestValueError
from
aesara.graph.utils
import
TestValueError
from
aesara.scan
import
utils
from
aesara.scan
import
utils
from
aesara.scan.op
import
Scan
from
aesara.scan.op
import
Scan
,
ScanInfo
from
aesara.scan.utils
import
safe_new
,
traverse
from
aesara.scan.utils
import
safe_new
,
traverse
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.math
import
minimum
from
aesara.tensor.math
import
minimum
...
@@ -1022,31 +1022,32 @@ def scan(
...
@@ -1022,31 +1022,32 @@ def scan(
# Step 7. Create the Scan Op
# Step 7. Create the Scan Op
##
##
tap_array
=
mit_sot_tap_array
+
[[
-
1
]
for
x
in
range
(
n_sit_sot
)]
tap_array
=
tuple
(
tuple
(
v
)
for
v
in
mit_sot_tap_array
)
+
tuple
(
(
-
1
,)
for
x
in
range
(
n_sit_sot
)
)
if
allow_gc
is
None
:
if
allow_gc
is
None
:
allow_gc
=
config
.
scan__allow_gc
allow_gc
=
config
.
scan__allow_gc
info
=
OrderedDict
()
info
=
ScanInfo
(
info
[
"tap_array"
]
=
tap_array
tap_array
=
tap_array
,
info
[
"n_seqs"
]
=
n_seqs
n_seqs
=
n_seqs
,
info
[
"n_mit_mot"
]
=
n_mit_mot
n_mit_mot
=
n_mit_mot
,
info
[
"n_mit_mot_outs"
]
=
n_mit_mot_outs
n_mit_mot_outs
=
n_mit_mot_outs
,
info
[
"mit_mot_out_slices"
]
=
mit_mot_out_slices
mit_mot_out_slices
=
tuple
(
tuple
(
v
)
for
v
in
mit_mot_out_slices
),
info
[
"n_mit_sot"
]
=
n_mit_sot
n_mit_sot
=
n_mit_sot
,
info
[
"n_sit_sot"
]
=
n_sit_sot
n_sit_sot
=
n_sit_sot
,
info
[
"n_shared_outs"
]
=
n_shared_outs
n_shared_outs
=
n_shared_outs
,
info
[
"n_nit_sot"
]
=
n_nit_sot
n_nit_sot
=
n_nit_sot
,
info
[
"truncate_gradient"
]
=
truncate_gradient
truncate_gradient
=
truncate_gradient
,
info
[
"name"
]
=
name
name
=
name
,
info
[
"mode"
]
=
mode
gpua
=
False
,
info
[
"destroy_map"
]
=
OrderedDict
()
as_while
=
as_while
,
info
[
"gpua"
]
=
False
profile
=
profile
,
info
[
"as_while"
]
=
as_while
allow_gc
=
allow_gc
,
info
[
"profile"
]
=
profile
strict
=
strict
,
info
[
"allow_gc"
]
=
allow_gc
)
info
[
"strict"
]
=
strict
local_op
=
Scan
(
inner_inputs
,
new_outs
,
info
,
mode
)
local_op
=
Scan
(
inner_inputs
,
new_outs
,
info
)
##
##
# Step 8. Compute the outputs using the scan op
# Step 8. Compute the outputs using the scan op
...
...
aesara/scan/op.py
浏览文件 @
e85c7fd0
差异被折叠。
点击展开。
aesara/scan/opt.py
浏览文件 @
e85c7fd0
差异被折叠。
点击展开。
aesara/scan/utils.py
浏览文件 @
e85c7fd0
差异被折叠。
点击展开。
requirements.txt
浏览文件 @
e85c7fd0
-e ./
-e ./
dataclasses
>=0.7; python_version < '3.7'
filelock
filelock
flake8
==3.8.4
flake8
==3.8.4
pep8
pep8
...
...
setup.py
浏览文件 @
e85c7fd0
#!/usr/bin/env python
#!/usr/bin/env python
import
sys
from
setuptools
import
find_packages
,
setup
from
setuptools
import
find_packages
,
setup
import
versioneer
import
versioneer
...
@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9
...
@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9
"""
"""
CLASSIFIERS
=
[
_f
for
_f
in
CLASSIFIERS
.
split
(
"
\n
"
)
if
_f
]
CLASSIFIERS
=
[
_f
for
_f
in
CLASSIFIERS
.
split
(
"
\n
"
)
if
_f
]
install_requires
=
[
"numpy>=1.17.0"
,
"scipy>=0.14"
,
"filelock"
]
if
sys
.
version_info
[
0
:
2
]
<
(
3
,
7
):
install_requires
+=
[
"dataclasses"
]
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
setup
(
setup
(
name
=
NAME
,
name
=
NAME
,
...
@@ -57,7 +64,7 @@ if __name__ == "__main__":
...
@@ -57,7 +64,7 @@ if __name__ == "__main__":
license
=
LICENSE
,
license
=
LICENSE
,
platforms
=
PLATFORMS
,
platforms
=
PLATFORMS
,
packages
=
find_packages
(
exclude
=
[
"tests"
,
"tests.*"
]),
packages
=
find_packages
(
exclude
=
[
"tests"
,
"tests.*"
]),
install_requires
=
[
"numpy>=1.17.0"
,
"scipy>=0.14"
,
"filelock"
]
,
install_requires
=
install_requires
,
package_data
=
{
package_data
=
{
""
:
[
""
:
[
"*.txt"
,
"*.txt"
,
...
...
tests/scan/test_utils.py
浏览文件 @
e85c7fd0
...
@@ -252,10 +252,7 @@ def test_ScanArgs():
...
@@ -252,10 +252,7 @@ def test_ScanArgs():
# The `scan_args` base class always clones the inner-graph;
# The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same)
# here we make sure it doesn't (and that all the inputs are the same)
assert
scan_args
.
inputs
==
scan_op
.
inputs
assert
scan_args
.
inputs
==
scan_op
.
inputs
scan_op_info
=
dict
(
scan_op
.
info
)
assert
scan_args
.
info
==
scan_op
.
info
# The `ScanInfo` dictionary has the wrong order and an extra entry
del
scan_op_info
[
"strict"
]
assert
dict
(
scan_args
.
info
)
==
scan_op_info
assert
scan_args
.
var_mappings
==
scan_op
.
var_mappings
assert
scan_args
.
var_mappings
==
scan_op
.
var_mappings
# Check that `ScanArgs.find_among_fields` works
# Check that `ScanArgs.find_among_fields` works
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论