Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4aea87c2
提交
4aea87c2
authored
2月 14, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
2月 17, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix bug when taking the L_op of a Scan with mit-mot and disconnected output gradients
上级
49cf9d22
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
89 行增加
和
34 行删除
+89
-34
op.py
pytensor/scan/op.py
+38
-34
test_basic.py
tests/scan/test_basic.py
+51
-0
没有找到文件。
pytensor/scan/op.py
浏览文件 @
4aea87c2
...
@@ -2509,13 +2509,25 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2509,13 +2509,25 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return
rval
return
rval
var_mappings
=
self
.
get_oinp_iinp_iout_oout_mappings
()
var_mappings
=
self
.
get_oinp_iinp_iout_oout_mappings
()
dC_dinps_t
=
[
None
for
inp
in
diff_inputs
]
disconnected_dC_dinps_t
=
[
True
for
inp
in
diff_inputs
]
disconnected_dC_dinps_t
=
[
True
for
inp
in
diff_inputs
]
n_mit_mot_outs
=
info
.
n_mit_mot_outs
# In the case of mit-mot there can be more inner outputs than outer ones
n_extra_mit_mot_outs
=
n_mit_mot_outs
-
info
.
n_mit_mot
idx_nitsot_out_start
=
n_mit_mot_outs
+
info
.
n_mit_sot
+
info
.
n_sit_sot
idx_nitsot_out_end
=
idx_nitsot_out_start
+
info
.
n_nit_sot
# Create dummy variables for the internal input gradients
states
=
(
self
.
inner_mitmot
(
self_inputs
)
+
self
.
inner_mitsot
(
self_inputs
)
+
self
.
inner_sitsot
(
self_inputs
)
)
dC_dXts
=
[]
dC_dXts
=
[]
Xts
=
[]
Xts
=
[]
for
idx
,
Xt
in
enumerate
(
diff_outputs
):
for
idx
,
Xt
in
enumerate
(
diff_outputs
):
# We are looking for x[t-1] for a given x[t]
# We are looking for x[t-1] for a given x[t]
if
idx
>=
info
.
n_mit_mot_outs
:
if
idx
>=
n_mit_mot_outs
:
Xt_placeholder
=
safe_new
(
Xt
)
Xt_placeholder
=
safe_new
(
Xt
)
Xts
.
append
(
Xt_placeholder
)
Xts
.
append
(
Xt_placeholder
)
...
@@ -2523,9 +2535,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2523,9 +2535,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# or not. NOTE : This cannot be done by using
# or not. NOTE : This cannot be done by using
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# the exact same variable can be used as multiple outputs.
# the exact same variable can be used as multiple outputs.
idx_nitsot_start
=
info
.
n_mit_mot
+
info
.
n_mit_sot
+
info
.
n_sit_sot
if
idx
<
idx_nitsot_out_start
or
idx
>=
idx_nitsot_out_end
:
idx_nitsot_end
=
idx_nitsot_start
+
info
.
n_nit_sot
if
idx
<
idx_nitsot_start
or
idx
>=
idx_nitsot_end
:
# What we do here is loop through dC_douts and collect all
# What we do here is loop through dC_douts and collect all
# those that are connected to the specific one and do an
# those that are connected to the specific one and do an
# upcast on all of their dtypes to get the dtype for this
# upcast on all of their dtypes to get the dtype for this
...
@@ -2533,12 +2543,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2533,12 +2543,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# specific previous step is defined or not is done somewhere
# specific previous step is defined or not is done somewhere
# else.
# else.
dtypes
=
[]
dtypes
=
[]
states
=
(
self
.
inner_mitmot
(
self_inputs
)
+
self
.
inner_mitsot
(
self_inputs
)
+
self
.
inner_sitsot
(
self_inputs
)
)
for
pos
,
inp
in
enumerate
(
states
):
for
pos
,
inp
in
enumerate
(
states
):
if
inp
in
graph_inputs
([
Xt
]):
if
inp
in
graph_inputs
([
Xt
]):
# Get the index of the outer output that to which
# Get the index of the outer output that to which
...
@@ -2555,35 +2559,39 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2555,35 +2559,39 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
new_dtype
=
config
.
floatX
new_dtype
=
config
.
floatX
dC_dXt
=
safe_new
(
Xt
,
dtype
=
new_dtype
)
dC_dXt
=
safe_new
(
Xt
,
dtype
=
new_dtype
)
else
:
else
:
if
isinstance
(
dC_douts
[
idx
]
.
type
,
DisconnectedType
):
# nit-sot outputs
# If not disconnected assume the output gradient type is a valid type for the input gradient
if
isinstance
(
dC_douts
[
idx
-
n_extra_mit_mot_outs
]
.
type
,
DisconnectedType
):
continue
continue
dC_dXt
=
safe_new
(
dC_douts
[
idx
][
0
])
dC_dXt
=
safe_new
(
dC_douts
[
idx
-
n_extra_mit_mot_outs
][
0
])
dC_dXts
.
append
(
dC_dXt
)
dC_dXts
.
append
(
dC_dXt
)
# Handle cases where the very same variable may be used as different outputs
# TODO: Couldn't we add a view Op to avoid this when building the Scan graph?
known_grads
=
{}
known_grads
=
{}
dc_dxts_idx
=
0
dc_dxts_idx
=
0
for
i
in
range
(
len
(
diff_outputs
)):
for
i
in
range
(
len
(
diff_outputs
)):
if
i
<
idx_nitsot_start
or
i
>=
idx_nitsot_end
:
if
not
(
i
<
idx_nitsot_out_start
or
i
>=
idx_nitsot_out_end
)
and
isinstance
(
if
diff_outputs
[
i
]
in
known_grads
:
dC_douts
[
i
-
n_extra_mit_mot_outs
]
.
type
,
DisconnectedType
known_grads
[
diff_outputs
[
i
]]
+=
dC_dXts
[
dc_dxts_idx
]
):
else
:
# Special case where we don't have a dC_dXt for disconnected nitsot outputs
known_grads
[
diff_outputs
[
i
]]
=
dC_dXts
[
dc_dxts_idx
]
continue
dc_dxts_idx
+=
1
# Just some trouble to avoid a +0
if
diff_outputs
[
i
]
in
known_grads
:
known_grads
[
diff_outputs
[
i
]]
+=
dC_dXts
[
dc_dxts_idx
]
else
:
else
:
if
isinstance
(
dC_douts
[
i
]
.
type
,
DisconnectedType
):
known_grads
[
diff_outputs
[
i
]]
=
dC_dXts
[
dc_dxts_idx
]
continue
dc_dxts_idx
+=
1
else
:
if
diff_outputs
[
i
]
in
known_grads
:
known_grads
[
diff_outputs
[
i
]]
+=
dC_dXts
[
dc_dxts_idx
]
else
:
known_grads
[
diff_outputs
[
i
]]
=
dC_dXts
[
dc_dxts_idx
]
dc_dxts_idx
+=
1
dC_dinps_t
=
compute_all_gradients
(
known_grads
)
dC_dinps_t
=
compute_all_gradients
(
known_grads
)
# mask inputs that get no gradients
# mask inputs that get no gradients
for
dx
in
range
(
len
(
dC_dinps_t
)):
for
dx
in
range
(
len
(
dC_dinps_t
)):
if
not
dC_dinps_t
[
dx
]
:
if
dC_dinps_t
[
dx
]
is
None
:
dC_dinps_t
[
dx
]
=
pt
.
zeros_like
(
diff_inputs
[
dx
])
dC_dinps_t
[
dx
]
=
dC_dinps_t
[
dx
]
=
pt
.
zeros_like
(
diff_inputs
[
dx
])
else
:
else
:
disconnected_dC_dinps_t
[
dx
]
=
False
disconnected_dC_dinps_t
[
dx
]
=
False
for
Xt
,
Xt_placeholder
in
zip
(
for
Xt
,
Xt_placeholder
in
zip
(
...
@@ -2846,7 +2854,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -2846,7 +2854,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for
idx
in
range
(
info
.
n_sit_sot
):
for
idx
in
range
(
info
.
n_sit_sot
):
mitmot_inp_taps
.
append
([
0
,
1
])
mitmot_inp_taps
.
append
([
0
,
1
])
mitmot_out_taps
.
append
([
1
])
mitmot_out_taps
.
append
([
1
])
through_shared
=
False
if
not
isinstance
(
dC_douts
[
idx
+
offset
]
.
type
,
DisconnectedType
):
if
not
isinstance
(
dC_douts
[
idx
+
offset
]
.
type
,
DisconnectedType
):
outer_inp_mitmot
.
append
(
dC_douts
[
idx
+
offset
][::
-
1
])
outer_inp_mitmot
.
append
(
dC_douts
[
idx
+
offset
][::
-
1
])
else
:
else
:
...
@@ -3007,9 +3014,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3007,9 +3014,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
name
=
f
"grad_of_{self.name}"
if
self
.
name
else
None
,
name
=
f
"grad_of_{self.name}"
if
self
.
name
else
None
,
allow_gc
=
self
.
allow_gc
,
allow_gc
=
self
.
allow_gc
,
)
)
outputs
=
local_op
(
*
outer_inputs
)
outputs
=
local_op
(
*
outer_inputs
,
return_list
=
True
)
if
not
isinstance
(
outputs
,
list
|
tuple
):
outputs
=
[
outputs
]
# Re-order the gradients correctly
# Re-order the gradients correctly
gradients
=
[
DisconnectedType
()()]
gradients
=
[
DisconnectedType
()()]
...
@@ -3095,7 +3100,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -3095,7 +3100,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
)
)
)
start
=
len
(
gradients
)
gradients
+=
[
DisconnectedType
()()
for
_
in
range
(
info
.
n_nit_sot
)]
gradients
+=
[
DisconnectedType
()()
for
_
in
range
(
info
.
n_nit_sot
)]
begin
=
end
begin
=
end
...
...
tests/scan/test_basic.py
浏览文件 @
4aea87c2
...
@@ -2128,6 +2128,57 @@ class TestScan:
...
@@ -2128,6 +2128,57 @@ class TestScan:
# TODO: We should test something about the Rop!
# TODO: We should test something about the Rop!
Rop
(
d_cost_wrt_pars
,
pars
,
p
)
Rop
(
d_cost_wrt_pars
,
pars
,
p
)
def
test_second_derivative_disconnected_cost_with_mit_mot
(
self
):
# This test is a regression test for a bug that was revealed
# when we computed the pushforward of a Scan gradient via two applications of pullback
seq
=
pt
.
vector
(
"seq"
,
shape
=
(
2
,))
z
=
pt
.
scalar
(
"z"
)
x0
=
pt
.
vector
(
"x0"
,
shape
=
(
2
,))
# When s is 1 and z is 2, xs[-1] is just a sneaky
# x ** 4 (after two nsteps)
# grad should be 4 * x ** 3
# and grad of grad should be 12 * x ** 2
def
step
(
s
,
xtm2
,
xtm1
,
z
):
return
s
*
((
xtm2
*
0
+
xtm1
)
**
2
)
*
(
z
/
2
)
xs
,
_
=
scan
(
step
,
sequences
=
[
seq
],
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
(
-
2
,
-
1
)}],
non_sequences
=
[
z
],
n_steps
=
2
,
)
last_x
=
xs
[
-
1
]
g_wrt_x0
,
g_wrt_z
,
g_wrt_seq
=
pt
.
grad
(
last_x
,
[
x0
,
z
,
seq
])
g
=
g_wrt_x0
.
sum
()
+
g_wrt_z
.
sum
()
*
0
+
g_wrt_seq
.
sum
()
*
0
assert
g
.
eval
({
seq
:
[
1
,
1
],
x0
:
[
1
,
1
],
z
:
2
})
==
4
gg
=
pt
.
grad
(
g
,
wrt
=
x0
)
.
sum
()
assert
gg
.
eval
({
seq
:
[
1
,
1
],
x0
:
[
1
,
1
],
z
:
2
})
==
12
assert
gg
.
eval
({
seq
:
[
2
,
2
],
x0
:
[
1
,
1
],
z
:
2
})
==
96
# Leave out z
g_wrt_x0
,
g_wrt_seq
=
pt
.
grad
(
last_x
,
[
x0
,
seq
])
g
=
g_wrt_x0
.
sum
()
+
g_wrt_seq
.
sum
()
*
0
gg
=
pt
.
grad
(
g
,
wrt
=
x0
)
.
sum
()
assert
gg
.
eval
({
seq
:
[
1
,
1
],
x0
:
[
1
,
1
],
z
:
2
})
==
12
assert
gg
.
eval
({
seq
:
[
2
,
2
],
x0
:
[
1
,
1
],
z
:
2
})
==
96
# Leave out seq
g_wrt_x0
,
g_wrt_z
=
pt
.
grad
(
last_x
,
[
x0
,
z
])
g
=
g_wrt_x0
.
sum
()
+
g_wrt_z
.
sum
()
*
0
gg
=
pt
.
grad
(
g
,
wrt
=
x0
)
.
sum
()
assert
gg
.
eval
({
seq
:
[
1
,
1
],
x0
:
[
1
,
1
],
z
:
2
})
==
12
assert
gg
.
eval
({
seq
:
[
1
,
1
],
x0
:
[
1
,
1
],
z
:
1
})
==
3
/
2
# Leave out z and seq
g_wrt_x0
=
pt
.
grad
(
last_x
,
x0
)
g
=
g_wrt_x0
.
sum
()
gg
=
pt
.
grad
(
g
,
wrt
=
x0
)
.
sum
()
assert
gg
.
eval
({
seq
:
[
1
,
1
],
x0
:
[
1
,
1
],
z
:
2
})
==
12
assert
gg
.
eval
({
seq
:
[
1
,
1
],
x0
:
[
1
,
1
],
z
:
1
})
==
3
/
2
@pytest.mark.skipif
(
@pytest.mark.skipif
(
not
config
.
cxx
,
reason
=
"G++ not available, so we need to skip this test."
not
config
.
cxx
,
reason
=
"G++ not available, so we need to skip this test."
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论