Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
79961a62
提交
79961a62
authored
11月 24, 2021
作者:
Fabian Hartmann
提交者:
Brandon T. Willard
12月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add a SolveTriangular Op
`Solve` has also been changed to match SciPy.
上级
6fce270b
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
315 行增加
和
169 行删除
+315
-169
slinalg.py
aesara/tensor/slinalg.py
+149
-71
test_numba.py
tests/link/test_numba.py
+27
-2
test_slinalg.py
tests/tensor/test_slinalg.py
+139
-96
没有找到文件。
aesara/tensor/slinalg.py
浏览文件 @
79961a62
import
logging
import
warnings
from
typing
import
Union
import
numpy
as
np
import
scipy.linalg
...
...
@@ -11,6 +12,7 @@ from aesara.tensor import as_tensor_variable
from
aesara.tensor
import
basic
as
aet
from
aesara.tensor
import
math
as
atm
from
aesara.tensor.type
import
matrix
,
tensor
,
vector
from
aesara.tensor.var
import
TensorVariable
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -259,93 +261,52 @@ def cho_solve(c_and_lower, b, check_finite=True):
return
CholeskySolve
(
lower
=
lower
,
check_finite
=
check_finite
)(
A
,
b
)
class
Solve
(
Op
):
"""
Solve a system of linear equations.
For on CPU and GPU.
"""
class
SolveBase
(
Op
):
"""Base class for `scipy.linalg` matrix equation solvers."""
__props__
=
(
"assume_a"
,
"lower"
,
"check_finite"
,
# "transposed"
"check_finite"
,
)
def
__init__
(
self
,
assume_a
=
"gen"
,
lower
=
False
,
check_finite
=
True
,
# transposed=False
check_finite
=
True
,
):
if
assume_a
not
in
(
"gen"
,
"sym"
,
"her"
,
"pos"
):
raise
ValueError
(
f
"{assume_a} is not a recognized matrix structure"
)
self
.
assume_a
=
assume_a
self
.
lower
=
lower
self
.
check_finite
=
check_finite
# self.transposed = transposed
def
__repr__
(
self
):
return
"Solve{
%
s}"
%
str
(
self
.
_props
())
def
perform
(
self
,
node
,
inputs
,
outputs
):
pass
def
make_node
(
self
,
A
,
b
):
A
=
as_tensor_variable
(
A
)
b
=
as_tensor_variable
(
b
)
assert
A
.
ndim
==
2
assert
b
.
ndim
in
[
1
,
2
]
# infer dtype by solving the most simple
# case with (1, 1) matrices
if
A
.
ndim
!=
2
:
raise
ValueError
(
f
"`A` must be a matrix; got {A.type} instead."
)
if
b
.
ndim
not
in
[
1
,
2
]:
raise
ValueError
(
f
"`b` must be a matrix or a vector; got {b.type} instead."
)
# Infer dtype by solving the most simple case with 1x1 matrices
o_dtype
=
scipy
.
linalg
.
solve
(
np
.
eye
(
1
)
.
astype
(
A
.
dtype
),
np
.
eye
(
1
)
.
astype
(
b
.
dtype
)
)
.
dtype
x
=
tensor
(
broadcastable
=
b
.
broadcastable
,
dtype
=
o_dtype
)
return
Apply
(
self
,
[
A
,
b
],
[
x
])
def
perform
(
self
,
node
,
inputs
,
output_storage
):
A
,
b
=
inputs
if
self
.
assume_a
!=
"gen"
:
# if self.transposed:
# if self.assume_a == "her":
# trans = "C"
# else:
# trans = "T"
# else:
# trans = "N"
rval
=
scipy
.
linalg
.
solve_triangular
(
A
,
b
,
lower
=
self
.
lower
,
check_finite
=
self
.
check_finite
,
# trans=trans
)
else
:
rval
=
scipy
.
linalg
.
solve
(
A
,
b
,
assume_a
=
self
.
assume_a
,
lower
=
self
.
lower
,
check_finite
=
self
.
check_finite
,
# transposed=self.transposed,
)
output_storage
[
0
][
0
]
=
rval
# computes shape of x where x = inv(A) * b
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
Ashape
,
Bshape
=
shapes
rows
=
Ashape
[
1
]
if
len
(
Bshape
)
==
1
:
# b is a Vector
if
len
(
Bshape
)
==
1
:
return
[(
rows
,)]
else
:
cols
=
Bshape
[
1
]
# b is a Matrix
cols
=
Bshape
[
1
]
return
[(
rows
,
cols
)]
def
L_op
(
self
,
inputs
,
outputs
,
output_gradients
):
r"""
Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
Symbolic expression for updates taken from [#]_.
...
...
@@ -364,31 +325,148 @@ class Solve(Op):
# We need to return (dC/d[inv(A)], dC/db)
c_bar
=
output_gradients
[
0
]
trans_solve_op
=
Solve
(
assume_a
=
self
.
assume_a
,
check_finite
=
self
.
check_finite
,
lower
=
not
self
.
lower
,
trans_solve_op
=
type
(
self
)(
**
{
k
:
(
not
getattr
(
self
,
k
)
if
k
==
"lower"
else
getattr
(
self
,
k
))
for
k
in
self
.
__props__
}
)
b_bar
=
trans_solve_op
(
A
.
T
,
c_bar
)
# force outer product if vector second input
A_bar
=
-
atm
.
outer
(
b_bar
,
c
)
if
c
.
ndim
==
1
else
-
b_bar
.
dot
(
c
.
T
)
if
self
.
assume_a
!=
"gen"
:
if
self
.
lower
:
A_bar
=
aet
.
tril
(
A_bar
)
else
:
A_bar
=
aet
.
triu
(
A_bar
)
return
[
A_bar
,
b_bar
]
def
__repr__
(
self
):
return
f
"{type(self).__name__}{self._props()}"
class
SolveTriangular
(
SolveBase
):
"""Solve a system of linear equations."""
__props__
=
(
"lower"
,
"trans"
,
"unit_diagonal"
,
"check_finite"
,
)
def
__init__
(
self
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
check_finite
=
True
,
):
super
()
.
__init__
(
lower
=
lower
,
check_finite
=
check_finite
)
self
.
trans
=
trans
self
.
unit_diagonal
=
unit_diagonal
def
perform
(
self
,
node
,
inputs
,
outputs
):
A
,
b
=
inputs
outputs
[
0
][
0
]
=
scipy
.
linalg
.
solve_triangular
(
A
,
b
,
lower
=
self
.
lower
,
trans
=
self
.
trans
,
unit_diagonal
=
self
.
unit_diagonal
,
check_finite
=
self
.
check_finite
,
)
def
L_op
(
self
,
inputs
,
outputs
,
output_gradients
):
res
=
super
()
.
L_op
(
inputs
,
outputs
,
output_gradients
)
if
self
.
lower
:
res
[
0
]
=
aet
.
tril
(
res
[
0
])
else
:
res
[
0
]
=
aet
.
triu
(
res
[
0
])
return
res
solvetriangular
=
SolveTriangular
()
def
solve_triangular
(
a
:
TensorVariable
,
b
:
TensorVariable
,
trans
:
Union
[
int
,
str
]
=
0
,
lower
:
bool
=
False
,
unit_diagonal
:
bool
=
False
,
check_finite
:
bool
=
True
,
)
->
TensorVariable
:
"""Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.
Parameters
----------
a
Square input data
b
Input data for the right hand side.
lower : bool, optional
Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
trans: {0, 1, 2, ‘N’, ‘T’, ‘C’}, optional
Type of system to solve:
trans system
0 or 'N' a x = b
1 or 'T' a^T x = b
2 or 'C' a^H x = b
unit_diagonal: bool, optional
If True, diagonal elements of `a` are assumed to be 1 and will not be referenced.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
"""
return
SolveTriangular
(
lower
=
lower
,
trans
=
trans
,
unit_diagonal
=
unit_diagonal
,
check_finite
=
check_finite
,
)(
a
,
b
)
class
Solve
(
SolveBase
):
"""
Solve a system of linear equations.
For on CPU and GPU.
"""
__props__
=
(
"assume_a"
,
"lower"
,
"check_finite"
,
)
def
__init__
(
self
,
assume_a
=
"gen"
,
lower
=
False
,
check_finite
=
True
,
):
if
assume_a
not
in
(
"gen"
,
"sym"
,
"her"
,
"pos"
):
raise
ValueError
(
f
"{assume_a} is not a recognized matrix structure"
)
super
()
.
__init__
(
lower
=
lower
,
check_finite
=
check_finite
)
self
.
assume_a
=
assume_a
def
perform
(
self
,
node
,
inputs
,
outputs
):
a
,
b
=
inputs
outputs
[
0
][
0
]
=
scipy
.
linalg
.
solve
(
a
=
a
,
b
=
b
,
lower
=
self
.
lower
,
check_finite
=
self
.
check_finite
,
assume_a
=
self
.
assume_a
,
)
solve
=
Solve
()
def
solve
(
a
,
b
,
assume_a
=
"gen"
,
lower
=
False
,
check_finite
=
True
):
"""
Solves the linear equation set ``a * x = b`` for the unknown ``x``
for square ``a`` matrix.
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
If the data matrix is known to be a particular type then supplying the
corresponding string to ``assume_a`` key chooses the dedicated solver.
...
...
@@ -432,8 +510,8 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True):
# TODO: These are deprecated; emit a warning
solve_lower_triangular
=
Solve
(
assume_a
=
"sym"
,
lower
=
True
)
solve_upper_triangular
=
Solve
(
assume_a
=
"sym"
,
lower
=
False
)
solve_lower_triangular
=
Solve
Triangular
(
lower
=
True
)
solve_upper_triangular
=
Solve
Triangular
(
lower
=
False
)
solve_symmetric
=
Solve
(
assume_a
=
"sym"
)
# TODO: Optimizations to replace multiplication by matrix inverse
...
...
tests/link/test_numba.py
浏览文件 @
79961a62
...
...
@@ -2174,6 +2174,31 @@ def test_Cholesky(x, lower, exc):
"gen"
,
None
,
),
],
)
def
test_Solve
(
A
,
x
,
lower
,
exc
):
g
=
slinalg
.
Solve
(
lower
)(
A
,
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
compare_numba_and_py
(
g_fg
,
[
i
.
tag
.
test_value
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
(
SharedVariable
,
Constant
))
],
)
@pytest.mark.parametrize
(
"A, x, lower, exc"
,
[
(
set_test_value
(
aet
.
dmatrix
(),
...
...
@@ -2185,8 +2210,8 @@ def test_Cholesky(x, lower, exc):
),
],
)
def
test_Solve
(
A
,
x
,
lower
,
exc
):
g
=
slinalg
.
Solve
(
lower
)(
A
,
x
)
def
test_Solve
Triangular
(
A
,
x
,
lower
,
exc
):
g
=
slinalg
.
Solve
Triangular
(
lower
)(
A
,
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
...
...
tests/tensor/test_slinalg.py
浏览文件 @
79961a62
import
functools
import
itertools
import
numpy
as
np
...
...
@@ -14,12 +15,15 @@ from aesara.tensor.slinalg import (
CholeskyGrad
,
CholeskySolve
,
Solve
,
SolveBase
,
SolveTriangular
,
cho_solve
,
cholesky
,
eigvalsh
,
expm
,
kron
,
solve
,
solve_triangular
,
)
from
aesara.tensor.type
import
dmatrix
,
matrix
,
tensor
,
vector
from
tests
import
unittest_tools
as
utt
...
...
@@ -170,122 +174,107 @@ def test_eigvalsh_grad():
)
class
TestSolve
(
utt
.
InferShapeTester
):
def
setup_method
(
self
):
self
.
op_class
=
Solve
self
.
op
=
Solve
()
super
()
.
setup_method
()
def
test_infer_shape
(
self
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
class
TestSolveBase
(
utt
.
InferShapeTester
):
@pytest.mark.parametrize
(
"A_func, b_func, error_message"
,
[
(
vector
,
matrix
,
"`A` must be a matrix.*"
),
(
functools
.
partial
(
tensor
,
dtype
=
"floatX"
,
broadcastable
=
(
False
,)
*
3
),
matrix
,
"`A` must be a matrix.*"
,
),
(
matrix
,
functools
.
partial
(
tensor
,
dtype
=
"floatX"
,
broadcastable
=
(
False
,)
*
3
),
"`b` must be a matrix or a vector.*"
,
),
],
)
def
test_make_node
(
self
,
A_func
,
b_func
,
error_message
):
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
with
pytest
.
raises
(
ValueError
,
match
=
error_message
):
A
=
A_func
()
b
=
b_func
()
SolveBase
()(
A
,
b
)
def
test__repr__
(
self
):
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
matrix
()
b
=
matrix
()
self
.
_compile_and_check
(
[
A
,
b
],
# aesara.function inputs
[
self
.
op
(
A
,
b
)],
# aesara.function outputs
# A must be square
[
np
.
asarray
(
rng
.
random
((
5
,
5
)),
dtype
=
config
.
floatX
),
np
.
asarray
(
rng
.
random
((
5
,
1
)),
dtype
=
config
.
floatX
),
],
self
.
op_class
,
warn
=
False
,
)
y
=
SolveBase
()(
A
,
b
)
assert
y
.
__repr__
()
==
"SolveBase{lower=False, check_finite=True}.0"
class
TestSolve
(
utt
.
InferShapeTester
):
def
test__init__
(
self
):
with
pytest
.
raises
(
ValueError
)
as
excinfo
:
Solve
(
assume_a
=
"test"
)
assert
"is not a recognized matrix structure"
in
str
(
excinfo
.
value
)
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,
1
),
(
5
,)])
def
test_infer_shape
(
self
,
b_shape
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
matrix
()
b
=
vector
()
b_val
=
np
.
asarray
(
rng
.
random
(
b_shape
),
dtype
=
config
.
floatX
)
b
=
aet
.
as_tensor_variable
(
b_val
)
.
type
()
self
.
_compile_and_check
(
[
A
,
b
],
# aesara.function inputs
[
self
.
op
(
A
,
b
)],
# aesara.function outputs
# A must be square
[
A
,
b
],
[
solve
(
A
,
b
)],
[
np
.
asarray
(
rng
.
random
((
5
,
5
)),
dtype
=
config
.
floatX
),
np
.
asarray
(
rng
.
random
((
5
)),
dtype
=
config
.
floatX
)
,
b_val
,
],
self
.
op_class
,
Solve
,
warn
=
False
,
)
def
test_
solve_
correctness
(
self
):
def
test_correctness
(
self
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
matrix
()
b
=
matrix
()
y
=
s
elf
.
op
(
A
,
b
)
y
=
s
olve
(
A
,
b
)
gen_solve_func
=
aesara
.
function
([
A
,
b
],
y
)
cholesky_lower
=
Cholesky
(
lower
=
True
)
L
=
cholesky_lower
(
A
)
y_lower
=
self
.
op
(
L
,
b
)
lower_solve_func
=
aesara
.
function
([
L
,
b
],
y_lower
)
cholesky_upper
=
Cholesky
(
lower
=
False
)
U
=
cholesky_upper
(
A
)
y_upper
=
self
.
op
(
U
,
b
)
upper_solve_func
=
aesara
.
function
([
U
,
b
],
y_upper
)
b_val
=
np
.
asarray
(
rng
.
random
((
5
,
1
)),
dtype
=
config
.
floatX
)
# 1-test general case
A_val
=
np
.
asarray
(
rng
.
random
((
5
,
5
)),
dtype
=
config
.
floatX
)
# positive definite matrix:
A_val
=
np
.
dot
(
A_val
.
transpose
(),
A_val
)
assert
np
.
allclose
(
scipy
.
linalg
.
solve
(
A_val
,
b_val
),
gen_solve_func
(
A_val
,
b_val
)
)
# 2-test lower traingular case
L_val
=
scipy
.
linalg
.
cholesky
(
A_val
,
lower
=
True
)
assert
np
.
allclose
(
scipy
.
linalg
.
solve_triangular
(
L_val
,
b_val
,
lower
=
True
),
lower_solve_func
(
L_val
,
b_val
),
A_undef
=
np
.
array
(
[
[
1
,
0
,
0
,
0
,
0
],
[
0
,
1
,
0
,
0
,
0
],
[
0
,
0
,
1
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
],
[
0
,
0
,
0
,
1
,
0
],
],
dtype
=
config
.
floatX
,
)
# 3-test upper traingular case
U_val
=
scipy
.
linalg
.
cholesky
(
A_val
,
lower
=
False
)
assert
np
.
allclose
(
scipy
.
linalg
.
solve_triangular
(
U_val
,
b_val
,
lower
=
False
),
upper_solve_func
(
U_val
,
b_val
),
scipy
.
linalg
.
solve
(
A_undef
,
b_val
),
gen_solve_func
(
A_undef
,
b_val
)
)
def
test_solve_dtype
(
self
):
dtypes
=
[
"uint8"
,
"uint16"
,
"uint32"
,
"uint64"
,
"int8"
,
"int16"
,
"int32"
,
"int64"
,
"float16"
,
"float32"
,
"float64"
,
]
A_val
=
np
.
eye
(
2
)
b_val
=
np
.
ones
((
2
,
1
))
# try all dtype combinations
for
A_dtype
,
b_dtype
in
itertools
.
product
(
dtypes
,
dtypes
):
A
=
matrix
(
dtype
=
A_dtype
)
b
=
matrix
(
dtype
=
b_dtype
)
x
=
solve
(
A
,
b
)
fn
=
function
([
A
,
b
],
x
)
x_result
=
fn
(
A_val
.
astype
(
A_dtype
),
b_val
.
astype
(
b_dtype
))
assert
x
.
dtype
==
x_result
.
dtype
@pytest.mark.parametrize
(
"m, n, assume_a, lower"
,
[
(
5
,
None
,
"gen"
,
False
),
(
5
,
None
,
"gen"
,
True
),
(
4
,
2
,
"gen"
,
False
),
(
4
,
2
,
"gen"
,
True
),
],
)
def
test_solve_grad
(
self
,
m
,
n
,
assume_a
,
lower
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
def
verify_solve_grad
(
self
,
m
,
n
,
assume_a
,
lower
,
rng
):
# ensure diagonal elements of A relatively large to avoid numerical
# precision issues
# Ensure diagonal elements of `A` are relatively large to avoid
# numerical precision issues
A_val
=
(
rng
.
normal
(
size
=
(
m
,
m
))
*
0.5
+
np
.
eye
(
m
))
.
astype
(
config
.
floatX
)
if
assume_a
!=
"gen"
:
if
lower
:
A_val
=
np
.
tril
(
A_val
)
else
:
A_val
=
np
.
triu
(
A_val
)
if
n
is
None
:
b_val
=
rng
.
normal
(
size
=
m
)
.
astype
(
config
.
floatX
)
else
:
...
...
@@ -298,22 +287,76 @@ class TestSolve(utt.InferShapeTester):
solve_op
=
Solve
(
assume_a
=
assume_a
,
lower
=
lower
)
utt
.
verify_grad
(
solve_op
,
[
A_val
,
b_val
],
3
,
rng
,
eps
=
eps
)
class
TestSolveTriangular
(
utt
.
InferShapeTester
):
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,
1
),
(
5
,)])
def
test_infer_shape
(
self
,
b_shape
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
matrix
()
b_val
=
np
.
asarray
(
rng
.
random
(
b_shape
),
dtype
=
config
.
floatX
)
b
=
aet
.
as_tensor_variable
(
b_val
)
.
type
()
self
.
_compile_and_check
(
[
A
,
b
],
[
solve_triangular
(
A
,
b
)],
[
np
.
asarray
(
rng
.
random
((
5
,
5
)),
dtype
=
config
.
floatX
),
b_val
,
],
SolveTriangular
,
warn
=
False
,
)
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
])
def
test_correctness
(
self
,
lower
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
b_val
=
np
.
asarray
(
rng
.
random
((
5
,
1
)),
dtype
=
config
.
floatX
)
A_val
=
np
.
asarray
(
rng
.
random
((
5
,
5
)),
dtype
=
config
.
floatX
)
A_val
=
np
.
dot
(
A_val
.
transpose
(),
A_val
)
C_val
=
scipy
.
linalg
.
cholesky
(
A_val
,
lower
=
lower
)
A
=
matrix
()
b
=
matrix
()
cholesky
=
Cholesky
(
lower
=
lower
)
C
=
cholesky
(
A
)
y_lower
=
solve_triangular
(
C
,
b
,
lower
=
lower
)
lower_solve_func
=
aesara
.
function
([
C
,
b
],
y_lower
)
assert
np
.
allclose
(
scipy
.
linalg
.
solve_triangular
(
C_val
,
b_val
,
lower
=
lower
),
lower_solve_func
(
C_val
,
b_val
),
)
@pytest.mark.parametrize
(
"m, n,
assume_a,
lower"
,
"m, n, lower"
,
[
(
5
,
None
,
"gen"
,
False
),
(
5
,
None
,
"gen"
,
True
),
(
4
,
2
,
"gen"
,
False
),
(
4
,
2
,
"gen"
,
True
),
(
5
,
None
,
"sym"
,
False
),
(
5
,
None
,
"sym"
,
True
),
(
4
,
2
,
"sym"
,
False
),
(
4
,
2
,
"sym"
,
True
),
(
5
,
None
,
False
),
(
5
,
None
,
True
),
(
4
,
2
,
False
),
(
4
,
2
,
True
),
],
)
def
test_solve_grad
(
self
,
m
,
n
,
assume_a
,
lower
):
def
test_solve_grad
(
self
,
m
,
n
,
lower
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
self
.
verify_solve_grad
(
m
,
n
,
assume_a
,
lower
,
rng
)
# Ensure diagonal elements of `A` are relatively large to avoid
# numerical precision issues
A_val
=
(
rng
.
normal
(
size
=
(
m
,
m
))
*
0.5
+
np
.
eye
(
m
))
.
astype
(
config
.
floatX
)
if
n
is
None
:
b_val
=
rng
.
normal
(
size
=
m
)
.
astype
(
config
.
floatX
)
else
:
b_val
=
rng
.
normal
(
size
=
(
m
,
n
))
.
astype
(
config
.
floatX
)
eps
=
None
if
config
.
floatX
==
"float64"
:
eps
=
2e-8
solve_op
=
SolveTriangular
(
lower
=
lower
)
utt
.
verify_grad
(
solve_op
,
[
A_val
,
b_val
],
3
,
rng
,
eps
=
eps
)
class
TestCholeskySolve
(
utt
.
InferShapeTester
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论