Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
b7c952bf
提交
b7c952bf
authored
2月 13, 2024
作者:
jessegrabowski
提交者:
Ricardo Vieira
4月 28, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add gradient for `SVD`
上级
eb18f0ea
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
193 行增加
和
3 行删除
+193
-3
nlinalg.py
pytensor/tensor/nlinalg.py
+118
-3
test_nlinalg.py
tests/tensor/test_nlinalg.py
+75
-0
没有找到文件。
pytensor/tensor/nlinalg.py
浏览文件 @
b7c952bf
import
warnings
import
warnings
from
collections.abc
import
Callable
from
collections.abc
import
Callable
,
Sequence
from
functools
import
partial
from
functools
import
partial
from
typing
import
Literal
from
typing
import
Literal
,
cast
import
numpy
as
np
import
numpy
as
np
from
numpy.core.numeric
import
normalize_axis_tuple
# type: ignore
from
numpy.core.numeric
import
normalize_axis_tuple
# type: ignore
...
@@ -15,7 +15,7 @@ from pytensor.tensor import basic as ptb
...
@@ -15,7 +15,7 @@ from pytensor.tensor import basic as ptb
from
pytensor.tensor
import
math
as
ptm
from
pytensor.tensor
import
math
as
ptm
from
pytensor.tensor.basic
import
as_tensor_variable
,
diagonal
from
pytensor.tensor.basic
import
as_tensor_variable
,
diagonal
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.type
import
dvector
,
lscalar
,
matrix
,
scalar
,
vector
from
pytensor.tensor.type
import
Variable
,
dvector
,
lscalar
,
matrix
,
scalar
,
vector
class
MatrixPinv
(
Op
):
class
MatrixPinv
(
Op
):
...
@@ -597,6 +597,121 @@ class SVD(Op):
...
@@ -597,6 +597,121 @@ class SVD(Op):
else
:
else
:
return
[
s_shape
]
return
[
s_shape
]
def L_op(
    self,
    inputs: Sequence[Variable],
    outputs: Sequence[Variable],
    output_grads: Sequence[Variable],
) -> list[Variable]:
    """
    Build the reverse-mode gradient graph of the SVD with respect to its input matrix.

    Adapted from the autograd implementation at
    https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
    and from the mxnet implementation described in [1]_.

    Parameters
    ----------
    inputs
        One-element sequence holding the matrix ``A`` that was decomposed.
    outputs
        Outputs of this Op. Only read when ``self.compute_uv`` is True, in which case they are
        unpacked as ``(U, s, VT)``.
    output_grads
        Gradients of the cost with respect to each output. When ``self.compute_uv`` is False only
        ``output_grads[0]`` (the gradient with respect to ``s``) is read. Otherwise entries whose
        type is ``DisconnectedType`` are replaced by zeros shaped like the matching output.

    Returns
    -------
    list of Variable
        One-element list with the gradient with respect to ``A``.

    Raises
    ------
    NotImplementedError
        If ``self.compute_uv`` and ``self.full_matrices`` are both True.

    References
    ----------
    .. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
    """

    def pullback_from_s(
        U: ptb.TensorVariable, VT: ptb.TensorVariable, ds: ptb.TensorVariable
    ) -> list[Variable]:
        # Gradient when only the singular values feed the cost: scale the columns of conj(U) by ds
        # and multiply by VT.
        return [(U.conj() * ds[..., None, :]) @ VT]

    (matrix_input,) = inputs
    A = cast(ptb.TensorVariable, matrix_input)

    if not self.compute_uv:
        # The Op only produced the singular values, but the gradient with respect to A needs U and VT
        # as well, so a reduced SVD of A is added to the graph here.
        U, _, VT = svd(A, full_matrices=False, compute_uv=True)
        return pullback_from_s(U, VT, cast(ptb.TensorVariable, output_grads[0]))

    if self.full_matrices:
        raise NotImplementedError(
            "Gradient of svd not implemented for full_matrices=True"
        )

    U, s, VT = [cast(ptb.TensorVariable, out) for out in outputs]

    # A user may request all three outputs but use only some of them in the cost; the gradients of the
    # unused outputs then arrive with DisconnectedType.
    disconnected_flags = [
        isinstance(grad.type, DisconnectedType) for grad in output_grads
    ]

    if all(disconnected_flags):
        # Not expected to be reached in practice: a fully disconnected SVD should be pruned from the
        # gradient graph. Kept for completeness.
        return [DisconnectedType()()]  # pragma: no cover

    if disconnected_flags == [True, False, True]:
        # Only s is connected. This is the compute_uv=False situation, except that U and VT are already
        # available as outputs and do not need to be recomputed.
        return pullback_from_s(U, VT, cast(ptb.TensorVariable, output_grads[1]))

    # Substitute zeros (shaped like the matching output) for every disconnected gradient.
    dU, ds, dVT = [
        cast(ptb.TensorVariable, out.zeros_like() if flag else grad)
        for flag, grad, out in zip(disconnected_flags, output_grads, (U, s, VT))
    ]

    V = VT.T
    dV = dVT.T

    m, n = A.shape[-2:]
    n_singular = ptm.min((m, n))
    identity = ptb.eye(n_singular)

    def h(t):
        """
        Return ``t`` with its magnitude clamped to at least ``eps``, keeping its sign.

        Used to build the approximation of s_i ** 2 - s_j ** 2 from [1]. Robust to identical singular
        values (singular matrix input), although gradients are still wrong in this case.
        """
        eps = 1e-8

        # sign(0) = 0 in pytensor, which would zero out the clamped value; treat zero as positive.
        nonzero_sign = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
        return ptm.maximum(ptm.abs(t), eps) * nonzero_sign

    # E[i, j] = 1 / (h(s[j] - s[i]) * h(s[j] + s[i])) off the diagonal and 0 on the diagonal.
    # NOTE(review): the `s[:, None]` indexing presumes `s` is a vector here -- confirm that batch
    # dimensions are handled by the caller.
    off_diagonal_ones = ptb.ones((n_singular, n_singular)) - identity
    clamped_gaps = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
    E = off_diagonal_ones / clamped_gaps

    utgu = U.T @ dU
    vtgv = VT @ dV

    # Three contributions, one per output gradient, summed in the same order as they are listed.
    from_dU = (E * (utgu - utgu.conj().T)) * s[..., None, :]
    from_ds = identity * ds[..., :, None]
    from_dV = s[..., :, None] * (E * (vtgv - vtgv.conj().T))
    A_bar = U.conj() @ (from_dU + from_ds + from_dV) @ VT

    # Extra terms for non-square A. `/` and `@` share precedence and associate left to right, so each
    # expression reads as ((X / s) @ dY) @ (I - ...).
    wide_term = (U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)).conj()
    tall_term = (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T

    # m < n picks the wide term, m > n the tall term; m == n uses A_bar unchanged.
    non_square_A_bar = ptb.switch(
        ptm.lt(m, n),
        A_bar + wide_term,
        A_bar + tall_term,
    )
    return [ptb.switch(ptm.eq(m, n), A_bar, non_square_A_bar)]
def
svd
(
a
,
full_matrices
:
bool
=
True
,
compute_uv
:
bool
=
True
):
def
svd
(
a
,
full_matrices
:
bool
=
True
,
compute_uv
:
bool
=
True
):
"""
"""
...
...
tests/tensor/test_nlinalg.py
浏览文件 @
b7c952bf
...
@@ -8,6 +8,7 @@ from numpy.testing import assert_array_almost_equal
...
@@ -8,6 +8,7 @@ from numpy.testing import assert_array_almost_equal
import
pytensor
import
pytensor
from
pytensor
import
function
from
pytensor
import
function
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.tensor.basic
import
as_tensor_variable
from
pytensor.tensor.math
import
_allclose
from
pytensor.tensor.math
import
_allclose
from
pytensor.tensor.nlinalg
import
(
from
pytensor.tensor.nlinalg
import
(
SVD
,
SVD
,
...
@@ -215,6 +216,80 @@ class TestSvd(utt.InferShapeTester):
...
@@ -215,6 +216,80 @@ class TestSvd(utt.InferShapeTester):
outputs
=
[
outputs
]
outputs
=
[
outputs
]
self
.
_compile_and_check
([
A
],
outputs
,
[
A_v
],
self
.
op_class
,
warn
=
False
)
self
.
_compile_and_check
([
A
],
outputs
,
[
A_v
],
self
.
op_class
,
warn
=
False
)
@pytest.mark.parametrize(
    "compute_uv, full_matrices, gradient_test_case",
    [(False, False, 0)]
    + [(True, full, case) for full in (False, True) for case in range(8)],
    ids=(
        ["compute_uv=False, full_matrices=False"]
        + [
            f"compute_uv=True, full_matrices=False, gradient={grad}"
            for grad in ("U", "s", "V", "U+s", "s+V", "U+V", "U+s+V", "None")
        ]
        + [
            f"compute_uv=True, full_matrices=True, gradient={grad}"
            for grad in ("U", "s", "V", "U+s", "s+V", "U+V", "U+s+V", "None")
        ]
    ),
)
@pytest.mark.parametrize(
    "shape",
    [(3, 3), (4, 3), (3, 4)],
    ids=["(3,3)", "(4,3)", "(3,4)"],
)
@pytest.mark.parametrize(
    "batched",
    [True, False],
    ids=["batched=True", "batched=False"],
)
def test_grad(self, compute_uv, full_matrices, gradient_test_case, shape, batched):
    """
    Check the gradient of ``svd`` for each combination of options, shape and batching.

    ``gradient_test_case`` selects which of the outputs (U, s, V) enter the scalar cost when
    ``compute_uv`` is True; case 7 uses none of them. With ``full_matrices=True`` the test only
    checks that requesting the gradient raises ``NotImplementedError``.
    """
    rng = np.random.default_rng(utt.fetch_seed())
    test_shape = (4, *shape) if batched else shape

    # NOTE(review): the sample comes from `self.rng` while the local `rng` is what gets handed to
    # `verify_grad`; both are presumably seeded by the test fixture -- confirm this is intended.
    A_v = self.rng.normal(size=test_shape).astype(config.floatX)

    if full_matrices:
        with pytest.raises(
            NotImplementedError,
            match="Gradient of svd not implemented for full_matrices=True",
        ):
            U, s, V = svd(
                self.A, compute_uv=compute_uv, full_matrices=full_matrices
            )
            pytensor.grad(s.sum(), self.A)
        return

    if compute_uv:

        def svd_fn(A, case=0):
            U, s, V = svd(A, compute_uv=compute_uv, full_matrices=full_matrices)
            # Each entry lazily builds the scalar cost for one case, so only the selected
            # expression is added to the graph.
            cost_builders = {
                0: lambda: U.sum(),
                1: lambda: s.sum(),
                2: lambda: V.sum(),
                3: lambda: U.sum() + s.sum(),
                4: lambda: s.sum() + V.sum(),
                5: lambda: U.sum() + V.sum(),
                6: lambda: U.sum() + s.sum() + V.sum(),
                # All inputs disconnected
                7: lambda: as_tensor_variable(3.0),
            }
            build_cost = cost_builders.get(case)
            return None if build_cost is None else build_cost()

        cost_fn = partial(svd_fn, case=gradient_test_case)
    else:
        cost_fn = partial(svd, compute_uv=compute_uv, full_matrices=full_matrices)

    utt.verify_grad(cost_fn, [A_v], rng=rng)
def
test_tensorsolve
():
def
test_tensorsolve
():
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论