Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1aa9a396
提交
1aa9a396
authored
2月 11, 2025
作者:
jessegrabowski
提交者:
Jesse Grabowski
4月 19, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
New Ops related to LU decomposition
上级
ee884b87
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
595 行增加
和
2 行删除
+595
-2
slinalg.py
pytensor/tensor/slinalg.py
+420
-2
test_slinalg.py
tests/tensor/test_slinalg.py
+175
-0
没有找到文件。
pytensor/tensor/slinalg.py
浏览文件 @
1aa9a396
...
...
@@ -10,7 +10,9 @@ from numpy.exceptions import ComplexWarning
import
pytensor
import
pytensor.tensor
as
pt
from
pytensor.graph.basic
import
Apply
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.gradient
import
DisconnectedType
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.op
import
Op
from
pytensor.tensor
import
TensorLike
,
as_tensor_variable
from
pytensor.tensor
import
basic
as
ptb
...
...
@@ -225,6 +227,7 @@ class SolveBase(Op):
):
self
.
lower
=
lower
self
.
check_finite
=
check_finite
assert
b_ndim
in
(
1
,
2
)
self
.
b_ndim
=
b_ndim
if
b_ndim
==
1
:
...
...
@@ -302,10 +305,14 @@ class SolveBase(Op):
solve_op
=
type
(
self
)(
**
props_dict
)
b_bar
=
solve_op
(
A
.
T
,
c_bar
)
b_bar
=
solve_op
(
A
.
m
T
,
c_bar
)
# force outer product if vector second input
A_bar
=
-
ptm
.
outer
(
b_bar
,
c
)
if
c
.
ndim
==
1
else
-
b_bar
.
dot
(
c
.
T
)
if
props_dict
.
get
(
"unit_diagonal"
,
False
):
n
=
A_bar
.
shape
[
-
1
]
A_bar
=
A_bar
[
pt
.
arange
(
n
),
pt
.
arange
(
n
)]
.
set
(
pt
.
zeros
(
n
))
return
[
A_bar
,
b_bar
]
...
...
@@ -394,6 +401,411 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
)(
A
,
b
)
class
LU
(
Op
):
"""Decompose a matrix into lower and upper triangular matrices."""
__props__
=
(
"permute_l"
,
"overwrite_a"
,
"check_finite"
,
"p_indices"
)
def
__init__
(
self
,
*
,
permute_l
=
False
,
overwrite_a
=
False
,
check_finite
=
True
,
p_indices
=
False
):
if
permute_l
and
p_indices
:
raise
ValueError
(
"Only one of permute_l and p_indices can be True"
)
self
.
permute_l
=
permute_l
self
.
check_finite
=
check_finite
self
.
p_indices
=
p_indices
self
.
overwrite_a
=
overwrite_a
if
self
.
permute_l
:
# permute_l overrides p_indices in the scipy function. We can copy that behavior
self
.
gufunc_signature
=
"(m,m)->(m,m),(m,m)"
elif
self
.
p_indices
:
self
.
gufunc_signature
=
"(m,m)->(m),(m,m),(m,m)"
else
:
self
.
gufunc_signature
=
"(m,m)->(m,m),(m,m),(m,m)"
if
self
.
overwrite_a
:
self
.
destroy_map
=
{
0
:
[
0
]}
if
self
.
permute_l
else
{
1
:
[
0
]}
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
n
=
shapes
[
0
][
0
]
if
self
.
permute_l
:
return
[(
n
,
n
),
(
n
,
n
)]
elif
self
.
p_indices
:
return
[(
n
,),
(
n
,
n
),
(
n
,
n
)]
else
:
return
[(
n
,
n
),
(
n
,
n
),
(
n
,
n
)]
def
make_node
(
self
,
x
):
x
=
as_tensor_variable
(
x
)
if
x
.
type
.
ndim
!=
2
:
raise
TypeError
(
f
"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
)
real_dtype
=
"f"
if
np
.
dtype
(
x
.
type
.
dtype
)
.
char
in
"fF"
else
"d"
p_dtype
=
"int32"
if
self
.
p_indices
else
np
.
dtype
(
real_dtype
)
L
=
tensor
(
shape
=
x
.
type
.
shape
,
dtype
=
x
.
type
.
dtype
)
U
=
tensor
(
shape
=
x
.
type
.
shape
,
dtype
=
x
.
type
.
dtype
)
if
self
.
permute_l
:
# In this case, L is actually P @ L
return
Apply
(
self
,
inputs
=
[
x
],
outputs
=
[
L
,
U
])
if
self
.
p_indices
:
p_indices
=
tensor
(
shape
=
(
x
.
type
.
shape
[
0
],),
dtype
=
p_dtype
)
return
Apply
(
self
,
inputs
=
[
x
],
outputs
=
[
p_indices
,
L
,
U
])
P
=
tensor
(
shape
=
x
.
type
.
shape
,
dtype
=
p_dtype
)
return
Apply
(
self
,
inputs
=
[
x
],
outputs
=
[
P
,
L
,
U
])
def
perform
(
self
,
node
,
inputs
,
outputs
):
[
A
]
=
inputs
out
=
scipy_linalg
.
lu
(
A
,
permute_l
=
self
.
permute_l
,
overwrite_a
=
self
.
overwrite_a
,
check_finite
=
self
.
check_finite
,
p_indices
=
self
.
p_indices
,
)
outputs
[
0
][
0
]
=
out
[
0
]
outputs
[
1
][
0
]
=
out
[
1
]
if
not
self
.
permute_l
:
# In all cases except permute_l, there are three returns
outputs
[
2
][
0
]
=
out
[
2
]
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
0
in
allowed_inplace_inputs
:
new_props
=
self
.
_props_dict
()
# type: ignore
new_props
[
"overwrite_a"
]
=
True
return
type
(
self
)(
**
new_props
)
else
:
return
self
def
L_op
(
self
,
inputs
:
Sequence
[
ptb
.
Variable
],
outputs
:
Sequence
[
ptb
.
Variable
],
output_grads
:
Sequence
[
ptb
.
Variable
],
)
->
list
[
ptb
.
Variable
]:
r"""
Derivation is due to Differentiation of Matrix Functionals Using Triangular Factorization
F. R. De Hoog, R.S. Anderssen, M. A. Lukas
"""
[
A
]
=
inputs
A
=
cast
(
TensorVariable
,
A
)
if
self
.
permute_l
:
# P has no gradient contribution (by assumption...), so PL_bar is the same as L_bar
L_bar
,
U_bar
=
output_grads
# TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
# We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass
P_or_indices
,
L
,
U
=
lu
(
# type: ignore
A
,
permute_l
=
False
,
check_finite
=
self
.
check_finite
,
p_indices
=
False
)
else
:
# In both other cases, there are 3 outputs. The first output will either be the permutation index itself,
# or indices that can be used to reconstruct the permutation matrix.
P_or_indices
,
L
,
U
=
outputs
_
,
L_bar
,
U_bar
=
output_grads
L_bar
=
(
L_bar
if
not
isinstance
(
L_bar
.
type
,
DisconnectedType
)
else
pt
.
zeros_like
(
A
)
)
U_bar
=
(
U_bar
if
not
isinstance
(
U_bar
.
type
,
DisconnectedType
)
else
pt
.
zeros_like
(
A
)
)
x1
=
ptb
.
tril
(
L
.
T
@
L_bar
,
k
=-
1
)
x2
=
ptb
.
triu
(
U_bar
@
U
.
T
)
LT_inv_x
=
solve_triangular
(
L
.
T
,
x1
+
x2
,
lower
=
False
,
unit_diagonal
=
True
)
# Where B = P.T @ A is a change of variable to avoid the permutation matrix in the gradient derivation
B_bar
=
solve_triangular
(
U
,
LT_inv_x
.
T
,
lower
=
False
)
.
T
if
not
self
.
p_indices
:
A_bar
=
P_or_indices
@
B_bar
else
:
A_bar
=
B_bar
[
P_or_indices
]
return
[
A_bar
]
def
lu
(
a
:
TensorLike
,
permute_l
=
False
,
check_finite
=
True
,
p_indices
=
False
,
overwrite_a
:
bool
=
False
,
)
->
(
tuple
[
TensorVariable
,
TensorVariable
,
TensorVariable
]
|
tuple
[
TensorVariable
,
TensorVariable
]
):
"""
Factorize a matrix as the product of a unit lower triangular matrix and an upper triangular matrix:
... math::
A = P L U
Where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
Parameters
----------
a: TensorLike
Matrix to be factorized
permute_l: bool
If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
be returned in this case, and PL will not be lower triangular.
check_finite: bool
Whether to check that the input matrix contains only finite numbers.
p_indices: bool
If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
itself.
overwrite_a: bool
Ignored by Pytensor. Pytensor will always perform computation inplace if possible.
Returns
-------
P: TensorVariable
Permutation matrix, or array of integer indices for permutation matrix. Not returned if permute_l is True.
L: TensorVariable
Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
U: TensorVariable
Upper triangular matrix
"""
return
cast
(
tuple
[
TensorVariable
,
TensorVariable
,
TensorVariable
]
|
tuple
[
TensorVariable
,
TensorVariable
],
Blockwise
(
LU
(
permute_l
=
permute_l
,
p_indices
=
p_indices
,
check_finite
=
check_finite
)
)(
a
),
)
class
PivotToPermutations
(
Op
):
__props__
=
(
"inverse"
,)
def
__init__
(
self
,
inverse
=
True
):
self
.
inverse
=
inverse
def
make_node
(
self
,
pivots
):
pivots
=
as_tensor_variable
(
pivots
)
if
pivots
.
ndim
!=
1
:
raise
ValueError
(
"PivotToPermutations only works on 1-D inputs"
)
permutations
=
pivots
.
type
.
clone
(
dtype
=
"int64"
)()
return
Apply
(
self
,
[
pivots
],
[
permutations
])
def
perform
(
self
,
node
,
inputs
,
outputs
):
[
pivots
]
=
inputs
p_inv
=
np
.
arange
(
len
(
pivots
),
dtype
=
pivots
.
dtype
)
for
i
in
range
(
len
(
pivots
)):
p_inv
[
i
],
p_inv
[
pivots
[
i
]]
=
p_inv
[
pivots
[
i
]],
p_inv
[
i
]
if
self
.
inverse
:
outputs
[
0
][
0
]
=
p_inv
else
:
outputs
[
0
][
0
]
=
np
.
argsort
(
p_inv
)
def
pivot_to_permutation
(
p
:
TensorLike
,
inverse
=
False
)
->
Variable
:
p
=
pt
.
as_tensor_variable
(
p
)
return
PivotToPermutations
(
inverse
=
inverse
)(
p
)
class
LUFactor
(
Op
):
__props__
=
(
"overwrite_a"
,
"check_finite"
)
gufunc_signature
=
"(m,m)->(m,m),(m)"
def
__init__
(
self
,
*
,
overwrite_a
=
False
,
check_finite
=
True
):
self
.
overwrite_a
=
overwrite_a
self
.
check_finite
=
check_finite
if
self
.
overwrite_a
:
self
.
destroy_map
=
{
1
:
[
0
]}
def
make_node
(
self
,
A
):
A
=
as_tensor_variable
(
A
)
if
A
.
type
.
ndim
!=
2
:
raise
TypeError
(
f
"LU only allowed on matrix (2-D) inputs, got {A.type.ndim}-D input"
)
LU
=
matrix
(
shape
=
A
.
type
.
shape
,
dtype
=
A
.
type
.
dtype
)
pivots
=
vector
(
shape
=
(
A
.
type
.
shape
[
0
],),
dtype
=
"int64"
)
return
Apply
(
self
,
[
A
],
[
LU
,
pivots
])
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
n
=
shapes
[
0
][
0
]
return
[(
n
,
n
),
(
n
,)]
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
0
in
allowed_inplace_inputs
:
new_props
=
self
.
_props_dict
()
# type: ignore
new_props
[
"overwrite_a"
]
=
True
return
type
(
self
)(
**
new_props
)
else
:
return
self
def
perform
(
self
,
node
,
inputs
,
outputs
):
A
=
inputs
[
0
]
LU
,
p
=
scipy_linalg
.
lu_factor
(
A
,
overwrite_a
=
self
.
overwrite_a
,
check_finite
=
self
.
check_finite
)
outputs
[
0
][
0
]
=
LU
outputs
[
1
][
0
]
=
p
def
L_op
(
self
,
inputs
,
outputs
,
output_gradients
):
[
A
]
=
inputs
LU_bar
,
_
=
output_gradients
LU
,
p_indices
=
outputs
eye
=
ptb
.
identity_like
(
A
)
L
=
cast
(
TensorVariable
,
ptb
.
tril
(
LU
,
k
=-
1
)
+
eye
)
U
=
cast
(
TensorVariable
,
ptb
.
triu
(
LU
))
p_indices
=
pivot_to_permutation
(
p_indices
,
inverse
=
False
)
# Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
L_bar
=
ptb
.
tril
(
LU_bar
,
k
=-
1
)
U_bar
=
ptb
.
triu
(
LU_bar
)
# From here we're in the same situation as the LU gradient derivation
x1
=
ptb
.
tril
(
L
.
T
@
L_bar
,
k
=-
1
)
x2
=
ptb
.
triu
(
U_bar
@
U
.
T
)
LT_inv_x
=
solve_triangular
(
L
.
T
,
x1
+
x2
,
lower
=
False
,
unit_diagonal
=
True
)
B_bar
=
solve_triangular
(
U
,
LT_inv_x
.
T
,
lower
=
False
)
.
T
A_bar
=
B_bar
[
p_indices
]
return
[
A_bar
]
def
lu_factor
(
a
:
TensorLike
,
*
,
check_finite
:
bool
=
True
,
overwrite_a
:
bool
=
False
,
)
->
tuple
[
TensorVariable
,
TensorVariable
]:
"""
LU factorization with partial pivoting.
Parameters
----------
a: TensorLike
Matrix to be factorized
check_finite: bool
Whether to check that the input matrix contains only finite numbers.
overwrite_a: bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
Returns
-------
LU: TensorVariable
LU decomposition of `a`
pivots: TensorVariable
An array of integers representin the pivot indices
"""
return
cast
(
tuple
[
TensorVariable
,
TensorVariable
],
Blockwise
(
LUFactor
(
check_finite
=
check_finite
))(
a
),
)
class
LUSolve
(
OpFromGraph
):
"""Solve a system of linear equations given the LU decomposition of the matrix."""
__props__
=
(
"trans"
,
"b_ndim"
,
"check_finite"
,
"overwrite_b"
)
def
__init__
(
self
,
inputs
:
list
[
Variable
],
outputs
:
list
[
Variable
],
trans
:
bool
=
False
,
b_ndim
:
int
|
None
=
None
,
check_finite
:
bool
=
False
,
overwrite_b
:
bool
=
False
,
**
kwargs
,
):
self
.
trans
=
trans
self
.
b_ndim
=
b_ndim
self
.
check_finite
=
check_finite
self
.
overwrite_b
=
overwrite_b
super
()
.
__init__
(
inputs
=
inputs
,
outputs
=
outputs
,
**
kwargs
)
def
lu_solve
(
LU_and_pivots
:
tuple
[
TensorLike
,
TensorLike
],
b
:
TensorLike
,
trans
:
bool
=
False
,
b_ndim
:
int
|
None
=
None
,
check_finite
:
bool
=
True
,
overwrite_b
:
bool
=
False
,
):
"""
Solve a system of linear equations given the LU decomposition of the matrix.
Parameters
----------
LU_and_pivots: tuple[TensorLike, TensorLike]
LU decomposition of the matrix, as returned by `lu_factor`
b: TensorLike
Right-hand side of the equation
trans: bool
If True, solve A^T x = b, instead of Ax = b. Default is False
b_ndim: int, optional
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
check_finite: bool
If True, check that the input matrices contain only finite numbers. Default is True.
overwrite_b: bool
Ignored by Pytensor. Pytensor will always compute inplace when possible.
"""
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
LU
,
pivots
=
LU_and_pivots
LU
,
pivots
,
b
=
map
(
pt
.
as_tensor_variable
,
[
LU
,
pivots
,
b
])
inv_permutation
=
pivot_to_permutation
(
pivots
,
inverse
=
True
)
x
=
b
[
inv_permutation
]
if
not
trans
else
b
x
=
solve_triangular
(
LU
,
x
,
lower
=
not
trans
,
unit_diagonal
=
not
trans
,
trans
=
trans
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
,
)
x
=
solve_triangular
(
LU
,
x
,
lower
=
trans
,
unit_diagonal
=
trans
,
trans
=
trans
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
,
)
x
=
x
[
pt
.
argsort
(
inv_permutation
)]
if
trans
else
x
return
x
class
SolveTriangular
(
SolveBase
):
"""Solve a system of linear equations."""
...
...
@@ -408,6 +820,9 @@ class SolveTriangular(SolveBase):
def
__init__
(
self
,
*
,
unit_diagonal
=
False
,
**
kwargs
):
if
kwargs
.
get
(
"overwrite_a"
,
False
):
raise
ValueError
(
"overwrite_a is not supported for SolverTriangulare"
)
# There's a naming inconsistency between solve_triangular (trans) and solve (transposed). Internally, we can use
# transpose everywhere, but expose the same API as scipy.linalg.solve_triangular
super
()
.
__init__
(
**
kwargs
)
self
.
unit_diagonal
=
unit_diagonal
...
...
@@ -1265,4 +1680,7 @@ __all__ = [
"solve_triangular"
,
"block_diag"
,
"cho_solve"
,
"lu"
,
"lu_factor"
,
"lu_solve"
,
]
tests/tensor/test_slinalg.py
浏览文件 @
1aa9a396
...
...
@@ -23,6 +23,10 @@ from pytensor.tensor.slinalg import (
cholesky
,
eigvalsh
,
expm
,
lu
,
lu_factor
,
lu_solve
,
pivot_to_permutation
,
solve
,
solve_continuous_lyapunov
,
solve_discrete_are
,
...
...
@@ -584,6 +588,177 @@ class TestCholeskySolve(utt.InferShapeTester):
assert
x
.
dtype
==
x_result
.
dtype
,
(
A_dtype
,
b_dtype
)
@pytest.mark.parametrize
(
"permute_l, p_indices"
,
[(
False
,
True
),
(
True
,
False
),
(
False
,
False
)],
ids
=
[
"PL"
,
"p_indices"
,
"P"
],
)
@pytest.mark.parametrize
(
"complex"
,
[
False
,
True
],
ids
=
[
"real"
,
"complex"
])
@pytest.mark.parametrize
(
"shape"
,
[(
3
,
5
,
5
),
(
5
,
5
)],
ids
=
[
"batched"
,
"not_batched"
])
def
test_lu_decomposition
(
permute_l
:
bool
,
p_indices
:
bool
,
complex
:
bool
,
shape
:
tuple
[
int
]
):
dtype
=
config
.
floatX
if
not
complex
else
f
"complex{int(config.floatX[-2:]) * 2}"
A
=
tensor
(
"A"
,
shape
=
shape
,
dtype
=
dtype
)
out
=
lu
(
A
,
permute_l
=
permute_l
,
p_indices
=
p_indices
)
f
=
pytensor
.
function
([
A
],
out
)
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
x
=
rng
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
)
if
complex
:
x
=
x
+
1
j
*
rng
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
)
out
=
f
(
x
)
if
permute_l
:
PL
,
U
=
out
elif
p_indices
:
p
,
L
,
U
=
out
if
len
(
shape
)
==
2
:
P
=
np
.
eye
(
5
)[
p
]
else
:
P
=
np
.
stack
([
np
.
eye
(
5
)[
idx
]
for
idx
in
p
])
PL
=
np
.
einsum
(
"...nk,...km->...nm"
,
P
,
L
)
else
:
P
,
L
,
U
=
out
PL
=
np
.
einsum
(
"...nk,...km->...nm"
,
P
,
L
)
x_rebuilt
=
np
.
einsum
(
"...nk,...km->...nm"
,
PL
,
U
)
np
.
testing
.
assert_allclose
(
x
,
x_rebuilt
,
atol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-4
,
rtol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-4
,
)
scipy_out
=
scipy
.
linalg
.
lu
(
x
,
permute_l
=
permute_l
,
p_indices
=
p_indices
)
for
a
,
b
in
zip
(
out
,
scipy_out
,
strict
=
True
):
np
.
testing
.
assert_allclose
(
a
,
b
)
@pytest.mark.parametrize
(
"grad_case"
,
[
0
,
1
,
2
],
ids
=
[
"dU_only"
,
"dL_only"
,
"dU_and_dL"
]
)
@pytest.mark.parametrize
(
"permute_l, p_indices"
,
[(
True
,
False
),
(
False
,
True
),
(
False
,
False
)],
ids
=
[
"PL"
,
"p_indices"
,
"P"
],
)
@pytest.mark.parametrize
(
"shape"
,
[(
3
,
5
,
5
),
(
5
,
5
)],
ids
=
[
"batched"
,
"not_batched"
])
def
test_lu_grad
(
grad_case
,
permute_l
,
p_indices
,
shape
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A_value
=
rng
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
)
def
f_pt
(
A
):
# lu returns either (P_or_index, L, U) or (PL, U), depending on settings
out
=
lu
(
A
,
permute_l
=
permute_l
,
p_indices
=
p_indices
,
check_finite
=
False
)
match
grad_case
:
case
0
:
return
out
[
-
1
]
.
sum
()
case
1
:
return
out
[
-
2
]
.
sum
()
case
2
:
return
out
[
-
1
]
.
sum
()
+
out
[
-
2
]
.
sum
()
utt
.
verify_grad
(
f_pt
,
[
A_value
],
rng
=
rng
)
@pytest.mark.parametrize
(
"inverse"
,
[
True
,
False
],
ids
=
[
"inverse"
,
"no_inverse"
])
def
test_pivot_to_permutation
(
inverse
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
_
,
pivots
=
scipy
.
linalg
.
lu_factor
(
A_val
)
perm_idx
,
*
_
=
scipy
.
linalg
.
lu
(
A_val
,
p_indices
=
True
)
if
not
inverse
:
perm_idx_pt
=
pivot_to_permutation
(
pivots
,
inverse
=
False
)
.
eval
()
np
.
testing
.
assert_array_equal
(
perm_idx_pt
,
perm_idx
)
else
:
p_inv_pt
=
pivot_to_permutation
(
pivots
,
inverse
=
True
)
.
eval
()
np
.
testing
.
assert_array_equal
(
p_inv_pt
,
np
.
argsort
(
perm_idx
))
class
TestLUSolve
(
utt
.
InferShapeTester
):
@staticmethod
def
factor_and_solve
(
A
,
b
,
sum
=
False
,
**
lu_kwargs
):
lu_and_pivots
=
lu_factor
(
A
)
x
=
lu_solve
(
lu_and_pivots
,
b
,
**
lu_kwargs
)
if
not
sum
:
return
x
return
x
.
sum
()
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,),
(
5
,
5
)],
ids
=
[
"b_vec"
,
"b_matrix"
])
@pytest.mark.parametrize
(
"trans"
,
[
True
,
False
],
ids
=
[
"x_T"
,
"x"
])
def
test_lu_solve
(
self
,
b_shape
:
tuple
[
int
],
trans
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
pt
.
tensor
(
"A"
,
shape
=
(
5
,
5
))
b
=
pt
.
tensor
(
"b"
,
shape
=
b_shape
)
A_val
=
(
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
+
np
.
eye
(
5
,
dtype
=
config
.
floatX
)
*
0.5
)
b_val
=
rng
.
normal
(
size
=
b_shape
)
.
astype
(
config
.
floatX
)
x
=
self
.
factor_and_solve
(
A
,
b
,
trans
=
trans
,
sum
=
False
)
f
=
pytensor
.
function
([
A
,
b
],
x
)
x_pt
=
f
(
A_val
.
copy
(),
b_val
.
copy
())
x_sp
=
scipy
.
linalg
.
lu_solve
(
scipy
.
linalg
.
lu_factor
(
A_val
.
copy
()),
b_val
.
copy
(),
trans
=
trans
)
np
.
testing
.
assert_allclose
(
x_pt
,
x_sp
)
def
T
(
x
):
if
trans
:
return
x
.
T
return
x
np
.
testing
.
assert_allclose
(
T
(
A_val
)
@
x_pt
,
b_val
,
atol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-4
,
rtol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-4
,
)
np
.
testing
.
assert_allclose
(
x_pt
,
x_sp
)
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,),
(
5
,
5
)],
ids
=
[
"b_vec"
,
"b_matrix"
])
@pytest.mark.parametrize
(
"trans"
,
[
True
,
False
],
ids
=
[
"x_T"
,
"x"
])
def
test_lu_solve_gradient
(
self
,
b_shape
:
tuple
[
int
],
trans
:
bool
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
b_shape
)
.
astype
(
config
.
floatX
)
test_fn
=
functools
.
partial
(
self
.
factor_and_solve
,
sum
=
True
,
trans
=
trans
)
utt
.
verify_grad
(
test_fn
,
[
A_val
,
b_val
],
3
,
rng
)
def
test_lu_factor
():
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
matrix
()
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
f
=
pytensor
.
function
([
A
],
lu_factor
(
A
))
LU
,
pt_p_idx
=
f
(
A_val
)
sp_LU
,
sp_p_idx
=
scipy
.
linalg
.
lu_factor
(
A_val
)
np
.
testing
.
assert_allclose
(
LU
,
sp_LU
)
np
.
testing
.
assert_allclose
(
pt_p_idx
,
sp_p_idx
)
utt
.
verify_grad
(
lambda
A
:
lu_factor
(
A
)[
0
]
.
sum
(),
[
A_val
],
rng
=
rng
,
)
def
test_cho_solve
():
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
matrix
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论