Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cdd8575b
提交
cdd8575b
authored
2月 01, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
5月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move Scan's as_while to ScanInfo
上级
81a8741c
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
34 行增加
和
55 行删除
+34
-55
scan.py
aesara/link/numba/dispatch/scan.py
+1
-1
basic.py
aesara/scan/basic.py
+1
-1
op.py
aesara/scan/op.py
+21
-29
opt.py
aesara/scan/opt.py
+7
-20
utils.py
aesara/scan/utils.py
+4
-4
没有找到文件。
aesara/link/numba/dispatch/scan.py
浏览文件 @
cdd8575b
...
@@ -120,7 +120,7 @@ def numba_funcify_Scan(op, node, **kwargs):
...
@@ -120,7 +120,7 @@ def numba_funcify_Scan(op, node, **kwargs):
]
]
while_logic
=
""
while_logic
=
""
if
op
.
as_while
:
if
op
.
info
.
as_while
:
# The inner function will be returning a boolean as last argument
# The inner function will be returning a boolean as last argument
inner_out_indexed
.
append
(
"while_flag"
)
inner_out_indexed
.
append
(
"while_flag"
)
while_logic
+=
"""
while_logic
+=
"""
...
...
aesara/scan/basic.py
浏览文件 @
cdd8575b
...
@@ -1140,6 +1140,7 @@ def scan(
...
@@ -1140,6 +1140,7 @@ def scan(
n_shared_outs
=
n_shared_outs
,
n_shared_outs
=
n_shared_outs
,
n_nit_sot
=
n_nit_sot
,
n_nit_sot
=
n_nit_sot
,
n_non_seqs
=
len
(
other_shared_inner_args
)
+
len
(
other_inner_args
),
n_non_seqs
=
len
(
other_shared_inner_args
)
+
len
(
other_inner_args
),
as_while
=
as_while
,
)
)
local_op
=
Scan
(
local_op
=
Scan
(
...
@@ -1149,7 +1150,6 @@ def scan(
...
@@ -1149,7 +1150,6 @@ def scan(
mode
=
mode
,
mode
=
mode
,
truncate_gradient
=
truncate_gradient
,
truncate_gradient
=
truncate_gradient
,
name
=
name
,
name
=
name
,
as_while
=
as_while
,
profile
=
profile
,
profile
=
profile
,
allow_gc
=
allow_gc
,
allow_gc
=
allow_gc
,
strict
=
strict
,
strict
=
strict
,
...
...
aesara/scan/op.py
浏览文件 @
cdd8575b
...
@@ -217,6 +217,7 @@ class ScanInfo:
...
@@ -217,6 +217,7 @@ class ScanInfo:
n_shared_outs
:
int
n_shared_outs
:
int
n_nit_sot
:
int
n_nit_sot
:
int
n_non_seqs
:
int
n_non_seqs
:
int
as_while
:
bool
TensorConstructorType
=
Callable
[[
List
[
bool
],
Union
[
str
,
np
.
generic
]],
TensorType
]
TensorConstructorType
=
Callable
[[
List
[
bool
],
Union
[
str
,
np
.
generic
]],
TensorType
]
...
@@ -670,8 +671,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -670,8 +671,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
as well as profiles for the computation of one step of each instance of
as well as profiles for the computation of one step of each instance of
`Scan`. The `name` of the instance appears in those profiles and can
`Scan`. The `name` of the instance appears in those profiles and can
greatly help to disambiguate information.
greatly help to disambiguate information.
as_while
Whether or not the `Scan` is a ``while``-loop.
profile
profile
If ``True`` or a non-empty string, a profile object will be created and
If ``True`` or a non-empty string, a profile object will be created and
attached to the inner graph of `Scan`. When `profile` is ``True``, the
attached to the inner graph of `Scan`. When `profile` is ``True``, the
...
@@ -701,7 +700,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -701,7 +700,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self
.
info
=
info
self
.
info
=
info
self
.
truncate_gradient
=
truncate_gradient
self
.
truncate_gradient
=
truncate_gradient
self
.
name
=
name
self
.
name
=
name
self
.
as_while
=
as_while
self
.
profile
=
profile
self
.
profile
=
profile
self
.
allow_gc
=
allow_gc
self
.
allow_gc
=
allow_gc
self
.
strict
=
strict
self
.
strict
=
strict
...
@@ -753,7 +751,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -753,7 +751,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for
o
in
outputs
[
end
:]:
for
o
in
outputs
[
end
:]:
self
.
output_types
.
append
(
o
.
type
)
self
.
output_types
.
append
(
o
.
type
)
if
self
.
as_while
:
if
info
.
as_while
:
self
.
output_types
=
self
.
output_types
[:
-
1
]
self
.
output_types
=
self
.
output_types
[:
-
1
]
if
not
hasattr
(
self
,
"name"
)
or
self
.
name
is
None
:
if
not
hasattr
(
self
,
"name"
)
or
self
.
name
is
None
:
...
@@ -1201,9 +1199,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1201,9 +1199,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if
self
.
info
!=
other
.
info
:
if
self
.
info
!=
other
.
info
:
return
False
return
False
if
self
.
as_while
!=
other
.
as_while
:
return
False
if
self
.
profile
!=
other
.
profile
:
if
self
.
profile
!=
other
.
profile
:
return
False
return
False
...
@@ -1234,7 +1229,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1234,7 +1229,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def
__str__
(
self
):
def
__str__
(
self
):
device_str
=
"cpu"
device_str
=
"cpu"
if
self
.
as_while
:
if
self
.
info
.
as_while
:
name
=
"do_while"
name
=
"do_while"
else
:
else
:
name
=
"for"
name
=
"for"
...
@@ -1261,7 +1256,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1261,7 +1256,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
type
(
self
),
type
(
self
),
self
.
_hash_inner_graph
,
self
.
_hash_inner_graph
,
self
.
info
,
self
.
info
,
self
.
as_while
,
self
.
profile
,
self
.
profile
,
self
.
truncate_gradient
,
self
.
truncate_gradient
,
self
.
name
,
self
.
name
,
...
@@ -1510,7 +1504,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1510,7 +1504,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self
.
info
.
n_mit_sot
,
self
.
info
.
n_mit_sot
,
self
.
info
.
n_sit_sot
,
self
.
info
.
n_sit_sot
,
self
.
info
.
n_nit_sot
,
self
.
info
.
n_nit_sot
,
self
.
as_while
,
self
.
info
.
as_while
,
cython_mintaps
,
cython_mintaps
,
self
.
info
.
tap_array
,
self
.
info
.
tap_array
,
tap_array_len
,
tap_array_len
,
...
@@ -1777,7 +1771,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1777,7 +1771,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
inner_output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
# 4.4. If there is a condition add it to the mix
# 4.4. If there is a condition add it to the mix
if
self
.
as_while
:
if
info
.
as_while
:
pdx
=
offset
+
info
.
n_shared_outs
pdx
=
offset
+
info
.
n_shared_outs
inner_output_storage
[
pdx
]
.
storage
[
0
]
=
None
inner_output_storage
[
pdx
]
.
storage
[
0
]
=
None
...
@@ -1847,7 +1841,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1847,7 +1841,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
raise
raise
dt_fn
=
time
.
time
()
-
t0_fn
dt_fn
=
time
.
time
()
-
t0_fn
if
self
.
as_while
:
if
info
.
as_while
:
pdx
=
offset
+
info
.
n_shared_outs
pdx
=
offset
+
info
.
n_shared_outs
cond
=
inner_output_storage
[
pdx
]
.
storage
[
0
]
==
0
cond
=
inner_output_storage
[
pdx
]
.
storage
[
0
]
==
0
...
@@ -2173,7 +2167,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2173,7 +2167,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for
in_ns
,
out_ns
in
zip
(
inner_non_sequences
,
node
.
inputs
[
offset
:]):
for
in_ns
,
out_ns
in
zip
(
inner_non_sequences
,
node
.
inputs
[
offset
:]):
out_equivalent
[
in_ns
]
=
out_ns
out_equivalent
[
in_ns
]
=
out_ns
if
self
.
as_while
:
if
info
.
as_while
:
self_outs
=
self
.
outputs
[:
-
1
]
self_outs
=
self
.
outputs
[:
-
1
]
else
:
else
:
self_outs
=
self
.
outputs
self_outs
=
self
.
outputs
...
@@ -2222,7 +2216,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2222,7 +2216,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
scan_outs
+=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
info
.
n_shared_outs
]]
scan_outs
+=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
info
.
n_shared_outs
]]
# if we are dealing with a repeat-until, then we do not know the
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
# leading dimension so we replace it for every entry with Shape_i
if
self
.
as_while
:
if
info
.
as_while
:
scan_outs_init
=
scan_outs
scan_outs_init
=
scan_outs
scan_outs
=
[]
scan_outs
=
[]
for
o
,
x
in
zip
(
node
.
outputs
,
scan_outs_init
):
for
o
,
x
in
zip
(
node
.
outputs
,
scan_outs_init
):
...
@@ -2312,7 +2306,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2312,7 +2306,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
)
else
:
else
:
grad_steps
=
inputs
[
0
]
grad_steps
=
inputs
[
0
]
if
self
.
as_while
:
if
info
.
as_while
:
n_steps
=
outs
[
0
]
.
shape
[
0
]
n_steps
=
outs
[
0
]
.
shape
[
0
]
# Restrict the number of grad steps according to
# Restrict the number of grad steps according to
...
@@ -2537,9 +2531,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2537,9 +2531,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dC_dinps_t
[
dx
+
info
.
n_seqs
]
=
dC_dXtm1
dC_dinps_t
[
dx
+
info
.
n_seqs
]
=
dC_dXtm1
else
:
else
:
dC_dinps_t
[
dx
+
info
.
n_seqs
]
+=
dC_dXtm1
dC_dinps_t
[
dx
+
info
.
n_seqs
]
+=
dC_dXtm1
# Construct scan op
# Seqs
if
info
.
as_while
:
if
self
.
as_while
:
# equivalent to x[:n_steps][::-1]
# equivalent to x[:n_steps][::-1]
outer_inp_seqs
=
[
x
[
n_steps
-
1
::
-
1
]
for
x
in
inputs
[
1
:
1
+
info
.
n_seqs
]]
outer_inp_seqs
=
[
x
[
n_steps
-
1
::
-
1
]
for
x
in
inputs
[
1
:
1
+
info
.
n_seqs
]]
else
:
else
:
...
@@ -2560,7 +2553,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2560,7 +2553,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
outer_inp_seqs
+=
[
x
[:
-
1
][::
-
1
]
for
x
in
self
.
outer_sitsot_outs
(
outs
)]
outer_inp_seqs
+=
[
x
[:
-
1
][::
-
1
]
for
x
in
self
.
outer_sitsot_outs
(
outs
)]
for
x
in
self
.
outer_nitsot_outs
(
dC_douts
):
for
x
in
self
.
outer_nitsot_outs
(
dC_douts
):
if
not
isinstance
(
x
.
type
,
DisconnectedType
):
if
not
isinstance
(
x
.
type
,
DisconnectedType
):
if
self
.
as_while
:
if
info
.
as_while
:
# equivalent to x[:n_steps][::-1]
# equivalent to x[:n_steps][::-1]
outer_inp_seqs
.
append
(
x
[
n_steps
-
1
::
-
1
])
outer_inp_seqs
.
append
(
x
[
n_steps
-
1
::
-
1
])
else
:
else
:
...
@@ -2572,7 +2565,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2572,7 +2565,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# fct add and we want to keep it for all Scan op. This is
# fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test
# used in T_Scan.test_grad_multiple_outs_taps to test
# that.
# that.
if
self
.
as_while
:
if
info
.
as_while
:
n
=
n_steps
.
tag
.
test_value
n
=
n_steps
.
tag
.
test_value
else
:
else
:
n
=
inputs
[
0
]
.
tag
.
test_value
n
=
inputs
[
0
]
.
tag
.
test_value
...
@@ -2585,7 +2578,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2585,7 +2578,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
assert
x
[::
-
1
][:
-
1
]
.
tag
.
test_value
.
shape
[
0
]
==
n
assert
x
[::
-
1
][:
-
1
]
.
tag
.
test_value
.
shape
[
0
]
==
n
for
x
in
self
.
outer_nitsot_outs
(
outs
):
for
x
in
self
.
outer_nitsot_outs
(
outs
):
if
hasattr
(
x
[::
-
1
]
.
tag
,
"test_value"
):
if
hasattr
(
x
[::
-
1
]
.
tag
,
"test_value"
):
if
self
.
as_while
:
if
info
.
as_while
:
assert
x
[
n_steps
-
1
::
-
1
]
.
tag
.
test_value
.
shape
[
0
]
==
n
assert
x
[
n_steps
-
1
::
-
1
]
.
tag
.
test_value
.
shape
[
0
]
==
n
else
:
else
:
assert
x
[::
-
1
]
.
tag
.
test_value
.
shape
[
0
]
==
n
assert
x
[::
-
1
]
.
tag
.
test_value
.
shape
[
0
]
==
n
...
@@ -2874,7 +2867,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2874,7 +2867,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+
outer_inp_seqs
+
outer_inp_seqs
+
outer_inp_mitmot
+
outer_inp_mitmot
+
outer_inp_sitsot
+
outer_inp_sitsot
+
[
n_steps
if
self
.
as_while
else
inputs
[
0
]
for
_
in
range
(
n_nit_sot
)]
+
[
n_steps
if
info
.
as_while
else
inputs
[
0
]
for
_
in
range
(
n_nit_sot
)]
+
self
.
outer_shared
(
inputs
)
+
self
.
outer_shared
(
inputs
)
+
self
.
outer_non_seqs
(
inputs
)
+
self
.
outer_non_seqs
(
inputs
)
)
)
...
@@ -2900,6 +2893,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2900,6 +2893,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_nit_sot
=
n_nit_sot
,
n_nit_sot
=
n_nit_sot
,
n_non_seqs
=
len
(
self
.
outer_shared
(
inputs
))
n_non_seqs
=
len
(
self
.
outer_shared
(
inputs
))
+
len
(
self
.
outer_non_seqs
(
inputs
)),
+
len
(
self
.
outer_non_seqs
(
inputs
)),
as_while
=
False
,
)
)
local_op
=
Scan
(
local_op
=
Scan
(
...
@@ -2908,7 +2902,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2908,7 +2902,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
out_info
,
out_info
,
mode
=
self
.
mode
,
mode
=
self
.
mode
,
truncate_gradient
=
self
.
truncate_gradient
,
truncate_gradient
=
self
.
truncate_gradient
,
as_while
=
False
,
profile
=
self
.
profile
,
profile
=
self
.
profile
,
name
=
f
"grad_of_{self.name}"
if
self
.
name
else
None
,
name
=
f
"grad_of_{self.name}"
if
self
.
name
else
None
,
allow_gc
=
self
.
allow_gc
,
allow_gc
=
self
.
allow_gc
,
...
@@ -2930,7 +2923,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2930,7 +2923,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# If the forward scan is in as_while mode, we need to pad
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# the gradients, so that they match the size of the input
# sequences.
# sequences.
if
self
.
as_while
:
if
info
.
as_while
:
n_zeros
=
inputs
[
0
]
-
n_steps
n_zeros
=
inputs
[
0
]
-
n_steps
shp
=
(
n_zeros
,)
shp
=
(
n_zeros
,)
if
x
.
ndim
>
1
:
if
x
.
ndim
>
1
:
...
@@ -2958,7 +2951,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2958,7 +2951,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# If the forward scan is in as_while mode, we need to pad
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# the gradients, so that they match the size of the input
# sequences.
# sequences.
if
self
.
as_while
:
if
info
.
as_while
:
n_zeros
=
inputs
[
0
]
-
grad_steps
n_zeros
=
inputs
[
0
]
-
grad_steps
shp
=
(
n_zeros
,)
shp
=
(
n_zeros
,)
if
x
.
ndim
>
1
:
if
x
.
ndim
>
1
:
...
@@ -3052,7 +3045,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3052,7 +3045,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Step 1. Compute the R_op of the inner function
# Step 1. Compute the R_op of the inner function
inner_eval_points
=
[
safe_new
(
x
,
"_evalpoint"
)
for
x
in
rop_of_inputs
]
inner_eval_points
=
[
safe_new
(
x
,
"_evalpoint"
)
for
x
in
rop_of_inputs
]
if
self
.
as_while
:
if
info
.
as_while
:
rop_self_outputs
=
self_outputs
[:
-
1
]
rop_self_outputs
=
self_outputs
[:
-
1
]
else
:
else
:
rop_self_outputs
=
self_outputs
rop_self_outputs
=
self_outputs
...
@@ -3209,7 +3202,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3209,7 +3202,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+
inner_out_shared
+
inner_out_shared
)
)
if
self
.
as_while
:
if
info
.
as_while
:
inner_outs
+=
[
self_outputs
[
-
1
]]
inner_outs
+=
[
self_outputs
[
-
1
]]
scan_inputs
=
(
scan_inputs
=
(
[
inputs
[
0
]]
[
inputs
[
0
]]
...
@@ -3233,6 +3226,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3233,6 +3226,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
tap_array
=
tuple
(
tuple
(
v
)
for
v
in
new_tap_array
),
tap_array
=
tuple
(
tuple
(
v
)
for
v
in
new_tap_array
),
mit_mot_out_slices
=
tuple
(
tuple
(
v
)
for
v
in
info
.
mit_mot_out_slices
)
*
2
,
mit_mot_out_slices
=
tuple
(
tuple
(
v
)
for
v
in
info
.
mit_mot_out_slices
)
*
2
,
n_non_seqs
=
len
(
inner_other
),
n_non_seqs
=
len
(
inner_other
),
as_while
=
info
.
as_while
,
)
)
local_op
=
Scan
(
local_op
=
Scan
(
...
@@ -3240,7 +3234,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3240,7 +3234,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_outs
,
inner_outs
,
out_info
,
out_info
,
mode
=
self
.
mode
,
mode
=
self
.
mode
,
as_while
=
self
.
as_while
,
profile
=
self
.
profile
,
profile
=
self
.
profile
,
truncate_gradient
=
self
.
truncate_gradient
,
truncate_gradient
=
self
.
truncate_gradient
,
name
=
f
"rop_of_{self.name}"
if
self
.
name
else
None
,
name
=
f
"rop_of_{self.name}"
if
self
.
name
else
None
,
...
@@ -3363,7 +3356,6 @@ def _op_debug_information_Scan(op, node):
...
@@ -3363,7 +3356,6 @@ def _op_debug_information_Scan(op, node):
inner_inputs
,
inner_inputs
,
inner_outputs
,
inner_outputs
,
node
.
op
.
info
,
node
.
op
.
info
,
node
.
op
.
as_while
,
clone
=
False
,
clone
=
False
,
)
)
...
...
aesara/scan/opt.py
浏览文件 @
cdd8575b
...
@@ -176,7 +176,6 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
...
@@ -176,7 +176,6 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
op_outs
,
op_outs
,
nw_info
,
nw_info
,
mode
=
op
.
mode
,
mode
=
op
.
mode
,
as_while
=
op
.
as_while
,
profile
=
op
.
profile
,
profile
=
op
.
profile
,
truncate_gradient
=
op
.
truncate_gradient
,
truncate_gradient
=
op
.
truncate_gradient
,
# TODO: This seems questionable
# TODO: This seems questionable
...
@@ -351,7 +350,6 @@ def push_out_non_seq_scan(fgraph, node):
...
@@ -351,7 +350,6 @@ def push_out_non_seq_scan(fgraph, node):
op_outs
,
op_outs
,
new_info
,
new_info
,
mode
=
op
.
mode
,
mode
=
op
.
mode
,
as_while
=
op
.
as_while
,
profile
=
op
.
profile
,
profile
=
op
.
profile
,
truncate_gradient
=
op
.
truncate_gradient
,
truncate_gradient
=
op
.
truncate_gradient
,
# TODO: This seems questionable
# TODO: This seems questionable
...
@@ -589,7 +587,6 @@ def push_out_seq_scan(fgraph, node):
...
@@ -589,7 +587,6 @@ def push_out_seq_scan(fgraph, node):
op_outs
,
op_outs
,
nw_info
,
nw_info
,
mode
=
op
.
mode
,
mode
=
op
.
mode
,
as_while
=
op
.
as_while
,
profile
=
op
.
profile
,
profile
=
op
.
profile
,
truncate_gradient
=
op
.
truncate_gradient
,
truncate_gradient
=
op
.
truncate_gradient
,
# TODO: This seems questionable
# TODO: This seems questionable
...
@@ -606,7 +603,7 @@ def push_out_seq_scan(fgraph, node):
...
@@ -606,7 +603,7 @@ def push_out_seq_scan(fgraph, node):
replacements
[
"remove"
]
=
[
node
]
replacements
[
"remove"
]
=
[
node
]
return
replacements
return
replacements
elif
not
to_keep_set
and
not
op
.
as_while
and
not
op
.
outer_mitmot
(
node
.
inputs
):
elif
not
to_keep_set
and
not
op
.
info
.
as_while
and
not
op
.
outer_mitmot
(
node
.
inputs
):
# Nothing in the inner graph should be kept
# Nothing in the inner graph should be kept
replace_with
=
{}
replace_with
=
{}
for
out
,
idx
in
to_replace_map
.
items
():
for
out
,
idx
in
to_replace_map
.
items
():
...
@@ -728,7 +725,6 @@ def push_out_inner_vars(
...
@@ -728,7 +725,6 @@ def push_out_inner_vars(
new_scan_node
.
op
.
inputs
,
new_scan_node
.
op
.
inputs
,
new_scan_node
.
op
.
outputs
,
new_scan_node
.
op
.
outputs
,
new_scan_node
.
op
.
info
,
new_scan_node
.
op
.
info
,
new_scan_node
.
op
.
as_while
,
)
)
new_outs
=
new_scan_args
.
outer_out_nit_sot
[
-
len
(
add_as_nitsots
)
:]
new_outs
=
new_scan_args
.
outer_out_nit_sot
[
-
len
(
add_as_nitsots
)
:]
...
@@ -770,7 +766,6 @@ def add_nitsot_outputs(
...
@@ -770,7 +766,6 @@ def add_nitsot_outputs(
new_scan_args
.
inner_outputs
,
new_scan_args
.
inner_outputs
,
new_scan_args
.
info
,
new_scan_args
.
info
,
mode
=
old_scan_node
.
op
.
mode
,
mode
=
old_scan_node
.
op
.
mode
,
as_while
=
old_scan_node
.
op
.
as_while
,
profile
=
old_scan_node
.
op
.
profile
,
profile
=
old_scan_node
.
op
.
profile
,
truncate_gradient
=
old_scan_node
.
op
.
truncate_gradient
,
truncate_gradient
=
old_scan_node
.
op
.
truncate_gradient
,
# TODO: This seems questionable
# TODO: This seems questionable
...
@@ -818,16 +813,14 @@ def push_out_add_scan(fgraph, node):
...
@@ -818,16 +813,14 @@ def push_out_add_scan(fgraph, node):
# Don't perform the optimization on `as_while` `Scan`s. Because these
# Don't perform the optimization on `as_while` `Scan`s. Because these
# `Scan`s don't run for a predetermined number of steps, handling them is
# `Scan`s don't run for a predetermined number of steps, handling them is
# more complicated and this optimization doesn't support it at the moment.
# more complicated and this optimization doesn't support it at the moment.
if
not
(
isinstance
(
node
.
op
,
Scan
)
and
not
node
.
op
.
as_while
):
if
not
(
isinstance
(
node
.
op
,
Scan
)
and
not
node
.
op
.
info
.
as_while
):
return
False
return
False
op
=
node
.
op
op
=
node
.
op
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
# use
args
=
ScanArgs
(
args
=
ScanArgs
(
node
.
inputs
,
node
.
outputs
,
op
.
inputs
,
op
.
outputs
,
op
.
info
)
node
.
inputs
,
node
.
outputs
,
op
.
inputs
,
op
.
outputs
,
op
.
info
,
op
.
as_while
)
clients
=
{}
clients
=
{}
local_fgraph_topo
=
io_toposort
(
local_fgraph_topo
=
io_toposort
(
...
@@ -997,7 +990,6 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -997,7 +990,6 @@ class ScanInplaceOptimizer(GlobalOptimizer):
op
.
info
,
op
.
info
,
mode
=
op
.
mode
,
mode
=
op
.
mode
,
typeConstructor
=
typeConstructor
,
typeConstructor
=
typeConstructor
,
as_while
=
op
.
as_while
,
profile
=
op
.
profile
,
profile
=
op
.
profile
,
truncate_gradient
=
op
.
truncate_gradient
,
truncate_gradient
=
op
.
truncate_gradient
,
# TODO: This seems questionable
# TODO: This seems questionable
...
@@ -1525,7 +1517,6 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1525,7 +1517,6 @@ def save_mem_new_scan(fgraph, node):
outs
,
outs
,
info
,
info
,
mode
=
op
.
mode
,
mode
=
op
.
mode
,
as_while
=
op
.
as_while
,
profile
=
op
.
profile
,
profile
=
op
.
profile
,
truncate_gradient
=
op
.
truncate_gradient
,
truncate_gradient
=
op
.
truncate_gradient
,
# TODO: This seems questionable
# TODO: This seems questionable
...
@@ -1662,7 +1653,7 @@ class ScanMerge(GlobalOptimizer):
...
@@ -1662,7 +1653,7 @@ class ScanMerge(GlobalOptimizer):
def
merge
(
self
,
nodes
):
def
merge
(
self
,
nodes
):
if
nodes
[
0
]
.
op
.
as_while
:
if
nodes
[
0
]
.
op
.
info
.
as_while
:
as_while
=
True
as_while
=
True
condition
=
nodes
[
0
]
.
op
.
outputs
[
-
1
]
condition
=
nodes
[
0
]
.
op
.
outputs
[
-
1
]
else
:
else
:
...
@@ -1813,6 +1804,7 @@ class ScanMerge(GlobalOptimizer):
...
@@ -1813,6 +1804,7 @@ class ScanMerge(GlobalOptimizer):
n_shared_outs
=
sum
(
nd
.
op
.
n_shared_outs
for
nd
in
nodes
),
n_shared_outs
=
sum
(
nd
.
op
.
n_shared_outs
for
nd
in
nodes
),
n_nit_sot
=
sum
(
nd
.
op
.
n_nit_sot
for
nd
in
nodes
),
n_nit_sot
=
sum
(
nd
.
op
.
n_nit_sot
for
nd
in
nodes
),
n_non_seqs
=
n_non_seqs
,
n_non_seqs
=
n_non_seqs
,
as_while
=
as_while
,
)
)
old_op
=
nodes
[
0
]
.
op
old_op
=
nodes
[
0
]
.
op
...
@@ -1825,7 +1817,6 @@ class ScanMerge(GlobalOptimizer):
...
@@ -1825,7 +1817,6 @@ class ScanMerge(GlobalOptimizer):
truncate_gradient
=
old_op
.
truncate_gradient
,
truncate_gradient
=
old_op
.
truncate_gradient
,
allow_gc
=
old_op
.
allow_gc
,
allow_gc
=
old_op
.
allow_gc
,
name
=
"&"
.
join
([
nd
.
op
.
name
for
nd
in
nodes
]),
name
=
"&"
.
join
([
nd
.
op
.
name
for
nd
in
nodes
]),
as_while
=
as_while
,
)
)
new_outs
=
new_op
(
*
outer_ins
)
new_outs
=
new_op
(
*
outer_ins
)
...
@@ -1846,7 +1837,7 @@ class ScanMerge(GlobalOptimizer):
...
@@ -1846,7 +1837,7 @@ class ScanMerge(GlobalOptimizer):
"""
"""
rep
=
set_nodes
[
0
]
rep
=
set_nodes
[
0
]
if
(
if
(
rep
.
op
.
as_while
!=
node
.
op
.
as_while
rep
.
op
.
info
.
as_while
!=
node
.
op
.
info
.
as_while
or
node
.
op
.
truncate_gradient
!=
rep
.
op
.
truncate_gradient
or
node
.
op
.
truncate_gradient
!=
rep
.
op
.
truncate_gradient
or
node
.
op
.
mode
!=
rep
.
op
.
mode
or
node
.
op
.
mode
!=
rep
.
op
.
mode
):
):
...
@@ -1872,7 +1863,7 @@ class ScanMerge(GlobalOptimizer):
...
@@ -1872,7 +1863,7 @@ class ScanMerge(GlobalOptimizer):
if
is_in_ancestors
(
node
,
nd
)
or
is_in_ancestors
(
nd
,
node
):
if
is_in_ancestors
(
node
,
nd
)
or
is_in_ancestors
(
nd
,
node
):
return
False
return
False
if
not
node
.
op
.
as_while
:
if
not
node
.
op
.
info
.
as_while
:
return
True
return
True
cond
=
node
.
op
.
outputs
[
-
1
]
cond
=
node
.
op
.
outputs
[
-
1
]
rep_cond
=
rep
.
op
.
outputs
[
-
1
]
rep_cond
=
rep
.
op
.
outputs
[
-
1
]
...
@@ -1957,7 +1948,6 @@ def scan_merge_inouts(fgraph, node):
...
@@ -1957,7 +1948,6 @@ def scan_merge_inouts(fgraph, node):
node
.
op
.
inputs
,
node
.
op
.
inputs
,
node
.
op
.
outputs
,
node
.
op
.
outputs
,
node
.
op
.
info
,
node
.
op
.
info
,
node
.
op
.
as_while
,
)
)
inp_equiv
=
{}
inp_equiv
=
{}
...
@@ -2001,7 +1991,6 @@ def scan_merge_inouts(fgraph, node):
...
@@ -2001,7 +1991,6 @@ def scan_merge_inouts(fgraph, node):
inner_outputs
,
inner_outputs
,
info
,
info
,
mode
=
node
.
op
.
mode
,
mode
=
node
.
op
.
mode
,
as_while
=
node
.
op
.
as_while
,
profile
=
node
.
op
.
profile
,
profile
=
node
.
op
.
profile
,
truncate_gradient
=
node
.
op
.
truncate_gradient
,
truncate_gradient
=
node
.
op
.
truncate_gradient
,
# TODO: This seems questionable
# TODO: This seems questionable
...
@@ -2019,7 +2008,6 @@ def scan_merge_inouts(fgraph, node):
...
@@ -2019,7 +2008,6 @@ def scan_merge_inouts(fgraph, node):
new_op
.
inputs
,
new_op
.
inputs
,
new_op
.
outputs
,
new_op
.
outputs
,
new_op
.
info
,
new_op
.
info
,
new_op
.
as_while
,
)
)
remove
=
[
node
]
remove
=
[
node
]
else
:
else
:
...
@@ -2266,7 +2254,6 @@ def push_out_dot1_scan(fgraph, node):
...
@@ -2266,7 +2254,6 @@ def push_out_dot1_scan(fgraph, node):
new_inner_outs
,
new_inner_outs
,
new_info
,
new_info
,
mode
=
op
.
mode
,
mode
=
op
.
mode
,
as_while
=
op
.
as_while
,
profile
=
op
.
profile
,
profile
=
op
.
profile
,
truncate_gradient
=
op
.
truncate_gradient
,
truncate_gradient
=
op
.
truncate_gradient
,
# TODO: This seems questionable
# TODO: This seems questionable
...
...
aesara/scan/utils.py
浏览文件 @
cdd8575b
...
@@ -408,6 +408,7 @@ def compress_outs(op, not_required, inputs):
...
@@ -408,6 +408,7 @@ def compress_outs(op, not_required, inputs):
n_shared_outs
=
0
,
n_shared_outs
=
0
,
n_nit_sot
=
0
,
n_nit_sot
=
0
,
n_non_seqs
=
0
,
n_non_seqs
=
0
,
as_while
=
op
.
info
.
as_while
,
)
)
op_inputs
=
op
.
inputs
[:
op
.
n_seqs
]
op_inputs
=
op
.
inputs
[:
op
.
n_seqs
]
...
@@ -528,7 +529,7 @@ def compress_outs(op, not_required, inputs):
...
@@ -528,7 +529,7 @@ def compress_outs(op, not_required, inputs):
op_inputs
+=
op
.
inputs
[
i_offset
:]
op_inputs
+=
op
.
inputs
[
i_offset
:]
info
=
dataclasses
.
replace
(
info
,
n_non_seqs
=
len
(
op
.
inputs
[
i_offset
:]))
info
=
dataclasses
.
replace
(
info
,
n_non_seqs
=
len
(
op
.
inputs
[
i_offset
:]))
node_inputs
+=
inputs
[
ni_offset
+
op
.
n_shared_outs
+
op
.
n_nit_sot
:]
node_inputs
+=
inputs
[
ni_offset
+
op
.
n_shared_outs
+
op
.
n_nit_sot
:]
if
op
.
as_while
:
if
op
.
info
.
as_while
:
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
map_old_new
[
o_offset
]
=
len
(
op_outputs
)
-
1
map_old_new
[
o_offset
]
=
len
(
op_outputs
)
-
1
# map_old_new[len(op_outputs)-1] = o_offset
# map_old_new[len(op_outputs)-1] = o_offset
...
@@ -582,11 +583,10 @@ class ScanArgs:
...
@@ -582,11 +583,10 @@ class ScanArgs:
_inner_inputs
:
Sequence
[
Variable
],
_inner_inputs
:
Sequence
[
Variable
],
_inner_outputs
:
Sequence
[
Variable
],
_inner_outputs
:
Sequence
[
Variable
],
info
:
"ScanInfo"
,
info
:
"ScanInfo"
,
as_while
:
bool
,
clone
:
Optional
[
bool
]
=
True
,
clone
:
Optional
[
bool
]
=
True
,
):
):
self
.
n_steps
=
outer_inputs
[
0
]
self
.
n_steps
=
outer_inputs
[
0
]
self
.
as_while
=
as_while
self
.
as_while
=
info
.
as_while
if
clone
:
if
clone
:
rval
=
reconstruct_graph
(
_inner_inputs
,
_inner_outputs
,
""
)
rval
=
reconstruct_graph
(
_inner_inputs
,
_inner_outputs
,
""
)
...
@@ -710,7 +710,6 @@ class ScanArgs:
...
@@ -710,7 +710,6 @@ class ScanArgs:
node
.
op
.
inputs
,
node
.
op
.
inputs
,
node
.
op
.
outputs
,
node
.
op
.
outputs
,
node
.
op
.
info
,
node
.
op
.
info
,
node
.
op
.
as_while
,
clone
=
clone
,
clone
=
clone
,
)
)
...
@@ -815,6 +814,7 @@ class ScanArgs:
...
@@ -815,6 +814,7 @@ class ScanArgs:
n_mit_mot_outs
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_out_slices
),
n_mit_mot_outs
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_out_slices
),
mit_mot_out_slices
=
tuple
(
self
.
mit_mot_out_slices
),
mit_mot_out_slices
=
tuple
(
self
.
mit_mot_out_slices
),
n_non_seqs
=
len
(
self
.
inner_in_non_seqs
),
n_non_seqs
=
len
(
self
.
inner_in_non_seqs
),
as_while
=
self
.
as_while
,
)
)
def
get_alt_field
(
def
get_alt_field
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论