Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d88c7351
提交
d88c7351
authored
4月 29, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
6月 10, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Decompose Tridiagonal Solve into core steps
上级
43d8e303
显示空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
533 行增加
和
33 行删除
+533
-33
mode.py
pytensor/compile/mode.py
+3
-0
tridiagonal.py
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+95
-9
rewriting.py
pytensor/tensor/_linalg/solve/rewriting.py
+55
-5
tridiagonal.py
pytensor/tensor/_linalg/solve/tridiagonal.py
+228
-0
__init__.py
tests/link/numba/linalg/__init__.py
+0
-0
__init__.py
tests/link/numba/linalg/solve/__init__.py
+0
-0
test_tridiagonal.py
tests/link/numba/linalg/solve/test_tridiagonal.py
+114
-0
test_rewriting.py
tests/tensor/linalg/test_rewriting.py
+38
-19
没有找到文件。
pytensor/compile/mode.py
浏览文件 @
d88c7351
...
...
@@ -477,6 +477,9 @@ JAX = Mode(
"fusion"
,
"inplace"
,
"scan_save_mem_prealloc"
,
# There are specific variants for the LU decompositions supported by JAX
"reuse_lu_decomposition_multiple_solves"
,
"scan_split_non_sequence_lu_decomposition_solve"
,
],
),
)
...
...
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
浏览文件 @
d88c7351
...
...
@@ -6,6 +6,7 @@ from numba.np.linalg import ensure_lapack
from
numpy
import
ndarray
from
scipy
import
linalg
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
numba_njit
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
...
...
@@ -20,6 +21,10 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_solve_check
,
_trans_char_to_int
,
)
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
LUFactorTridiagonal
,
SolveLUFactorTridiagonal
,
)
@numba_njit
...
...
@@ -34,7 +39,12 @@ def tridiagonal_norm(du, d, dl):
def
_gttrf
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
overwrite_dl
:
bool
,
overwrite_d
:
bool
,
overwrite_du
:
bool
,
)
->
tuple
[
ndarray
,
ndarray
,
ndarray
,
ndarray
,
ndarray
,
int
]:
"""Placeholder for LU factorization of tridiagonal matrix."""
return
# type: ignore
...
...
@@ -45,8 +55,12 @@ def gttrf_impl(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
overwrite_dl
:
bool
,
overwrite_d
:
bool
,
overwrite_du
:
bool
,
)
->
Callable
[
[
ndarray
,
ndarray
,
ndarray
],
tuple
[
ndarray
,
ndarray
,
ndarray
,
ndarray
,
ndarray
,
int
]
[
ndarray
,
ndarray
,
ndarray
,
bool
,
bool
,
bool
],
tuple
[
ndarray
,
ndarray
,
ndarray
,
ndarray
,
ndarray
,
int
],
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
dl
,
"gttrf"
)
...
...
@@ -60,12 +74,24 @@ def gttrf_impl(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
overwrite_dl
:
bool
,
overwrite_d
:
bool
,
overwrite_du
:
bool
,
)
->
tuple
[
ndarray
,
ndarray
,
ndarray
,
ndarray
,
ndarray
,
int
]:
n
=
np
.
int32
(
d
.
shape
[
-
1
])
ipiv
=
np
.
empty
(
n
,
dtype
=
np
.
int32
)
du2
=
np
.
empty
(
n
-
2
,
dtype
=
dtype
)
info
=
val_to_int_ptr
(
0
)
if
not
overwrite_dl
or
not
dl
.
flags
.
f_contiguous
:
dl
=
dl
.
copy
()
if
not
overwrite_d
or
not
d
.
flags
.
f_contiguous
:
d
=
d
.
copy
()
if
not
overwrite_du
or
not
du
.
flags
.
f_contiguous
:
du
=
du
.
copy
()
numba_gttrf
(
val_to_int_ptr
(
n
),
dl
.
view
(
w_type
)
.
ctypes
,
...
...
@@ -133,10 +159,23 @@ def gttrs_impl(
nrhs
=
1
if
b
.
ndim
==
1
else
int
(
b
.
shape
[
-
1
])
info
=
val_to_int_ptr
(
0
)
if
overwrite_b
and
b
.
flags
.
f_contiguous
:
b_copy
=
b
else
:
b_copy
=
_copy_to_fortran_order_even_if_1d
(
b
)
if
not
overwrite_b
or
not
b
.
flags
.
f_contiguous
:
b
=
_copy_to_fortran_order_even_if_1d
(
b
)
if
not
dl
.
flags
.
f_contiguous
:
dl
=
dl
.
copy
()
if
not
d
.
flags
.
f_contiguous
:
d
=
d
.
copy
()
if
not
du
.
flags
.
f_contiguous
:
du
=
du
.
copy
()
if
not
du2
.
flags
.
f_contiguous
:
du2
=
du2
.
copy
()
if
not
ipiv
.
flags
.
f_contiguous
:
ipiv
=
ipiv
.
copy
()
numba_gttrs
(
val_to_int_ptr
(
_trans_char_to_int
(
trans
)),
...
...
@@ -147,12 +186,12 @@ def gttrs_impl(
du
.
view
(
w_type
)
.
ctypes
,
du2
.
view
(
w_type
)
.
ctypes
,
ipiv
.
ctypes
,
b
_copy
.
view
(
w_type
)
.
ctypes
,
b
.
view
(
w_type
)
.
ctypes
,
val_to_int_ptr
(
n
),
info
,
)
return
b
_copy
,
int_ptr_to_val
(
info
)
return
b
,
int_ptr_to_val
(
info
)
return
impl
...
...
@@ -283,7 +322,9 @@ def _tridiagonal_solve_impl(
anorm
=
tridiagonal_norm
(
du
,
d
,
dl
)
dl
,
d
,
du
,
du2
,
IPIV
,
INFO
=
_gttrf
(
dl
,
d
,
du
)
dl
,
d
,
du
,
du2
,
IPIV
,
INFO
=
_gttrf
(
dl
,
d
,
du
,
overwrite_dl
=
True
,
overwrite_d
=
True
,
overwrite_du
=
True
)
_solve_check
(
n
,
INFO
)
X
,
INFO
=
_gttrs
(
...
...
@@ -297,3 +338,48 @@ def _tridiagonal_solve_impl(
return
X
return
impl
@numba_funcify.register
(
LUFactorTridiagonal
)
def
numba_funcify_LUFactorTridiagonal
(
op
:
LUFactorTridiagonal
,
node
,
**
kwargs
):
overwrite_dl
=
op
.
overwrite_dl
overwrite_d
=
op
.
overwrite_d
overwrite_du
=
op
.
overwrite_du
@numba_njit
(
cache
=
False
)
def
lu_factor_tridiagonal
(
dl
,
d
,
du
):
dl
,
d
,
du
,
du2
,
ipiv
,
_
=
_gttrf
(
dl
,
d
,
du
,
overwrite_dl
=
overwrite_dl
,
overwrite_d
=
overwrite_d
,
overwrite_du
=
overwrite_du
,
)
return
dl
,
d
,
du
,
du2
,
ipiv
return
lu_factor_tridiagonal
@numba_funcify.register
(
SolveLUFactorTridiagonal
)
def
numba_funcify_SolveLUFactorTridiagonal
(
op
:
SolveLUFactorTridiagonal
,
node
,
**
kwargs
):
overwrite_b
=
op
.
overwrite_b
transposed
=
op
.
transposed
@numba_njit
(
cache
=
False
)
def
solve_lu_factor_tridiagonal
(
dl
,
d
,
du
,
du2
,
ipiv
,
b
):
x
,
_
=
_gttrs
(
dl
,
d
,
du
,
du2
,
ipiv
,
b
,
overwrite_b
=
overwrite_b
,
trans
=
transposed
,
)
return
x
return
solve_lu_factor_tridiagonal
pytensor/tensor/_linalg/solve/rewriting.py
浏览文件 @
d88c7351
from
collections.abc
import
Container
from
copy
import
copy
from
pytensor.compile
import
optdb
from
pytensor.graph
import
Constant
,
graph_inputs
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
,
node_rewriter
from
pytensor.scan.op
import
Scan
from
pytensor.scan.rewriting
import
scan_seqopt1
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
tridiagonal_lu_factor
,
tridiagonal_lu_solve
,
)
from
pytensor.tensor.basic
import
atleast_Nd
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
...
...
@@ -17,18 +22,32 @@ from pytensor.tensor.variable import TensorVariable
def
decompose_A
(
A
,
assume_a
,
check_finite
):
if
assume_a
==
"gen"
:
return
lu_factor
(
A
,
check_finite
=
check_finite
)
elif
assume_a
==
"tridiagonal"
:
# We didn't implement check_finite for tridiagonal LU factorization
return
tridiagonal_lu_factor
(
A
)
else
:
raise
NotImplementedError
def
solve_lu_decomposed_system
(
A_decomp
,
b
,
transposed
=
False
,
*
,
core_solve_op
:
Solve
):
if
core_solve_op
.
assume_a
==
"gen"
:
b_ndim
=
core_solve_op
.
b_ndim
check_finite
=
core_solve_op
.
check_finite
assume_a
=
core_solve_op
.
assume_a
if
assume_a
==
"gen"
:
return
lu_solve
(
A_decomp
,
b
,
b_ndim
=
b_ndim
,
trans
=
transposed
,
b_ndim
=
core_solve_op
.
b_ndim
,
check_finite
=
core_solve_op
.
check_finite
,
check_finite
=
check_finite
,
)
elif
assume_a
==
"tridiagonal"
:
# We didn't implement check_finite for tridiagonal LU solve
return
tridiagonal_lu_solve
(
A_decomp
,
b
,
b_ndim
=
b_ndim
,
transposed
=
transposed
,
)
else
:
raise
NotImplementedError
...
...
@@ -189,13 +208,15 @@ def _scan_split_non_sequence_lu_decomposition_solve(
@register_specialize
@node_rewriter
([
Blockwise
])
def
reuse_lu_decomposition_multiple_solves
(
fgraph
,
node
):
return
_split_lu_solve_steps
(
fgraph
,
node
,
eager
=
False
,
allowed_assume_a
=
{
"gen"
})
return
_split_lu_solve_steps
(
fgraph
,
node
,
eager
=
False
,
allowed_assume_a
=
{
"gen"
,
"tridiagonal"
}
)
@node_rewriter
([
Scan
])
def
scan_split_non_sequence_lu_decomposition_solve
(
fgraph
,
node
):
return
_scan_split_non_sequence_lu_decomposition_solve
(
fgraph
,
node
,
allowed_assume_a
=
{
"gen"
}
fgraph
,
node
,
allowed_assume_a
=
{
"gen"
,
"tridiagonal"
}
)
...
...
@@ -207,3 +228,32 @@ scan_seqopt1.register(
"scan_pushout"
,
position
=
2
,
)
@node_rewriter
([
Blockwise
])
def
reuse_lu_decomposition_multiple_solves_jax
(
fgraph
,
node
):
return
_split_lu_solve_steps
(
fgraph
,
node
,
eager
=
False
,
allowed_assume_a
=
{
"gen"
})
optdb
[
"specialize"
]
.
register
(
reuse_lu_decomposition_multiple_solves_jax
.
__name__
,
in2out
(
reuse_lu_decomposition_multiple_solves_jax
,
ignore_newtrees
=
True
),
"jax"
,
use_db_name_as_tag
=
False
,
)
@node_rewriter
([
Scan
])
def
scan_split_non_sequence_lu_decomposition_solve_jax
(
fgraph
,
node
):
return
_scan_split_non_sequence_lu_decomposition_solve
(
fgraph
,
node
,
allowed_assume_a
=
{
"gen"
}
)
scan_seqopt1
.
register
(
scan_split_non_sequence_lu_decomposition_solve_jax
.
__name__
,
in2out
(
scan_split_non_sequence_lu_decomposition_solve_jax
,
ignore_newtrees
=
True
),
"jax"
,
use_db_name_as_tag
=
False
,
position
=
2
,
)
pytensor/tensor/_linalg/solve/tridiagonal.py
0 → 100644
浏览文件 @
d88c7351
import
typing
from
typing
import
TYPE_CHECKING
import
numpy
as
np
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.graph
import
Apply
,
Op
from
pytensor.tensor.basic
import
as_tensor
,
diagonal
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.type
import
tensor
,
vector
from
pytensor.tensor.variable
import
TensorVariable
if
TYPE_CHECKING
:
from
pytensor.tensor
import
TensorLike
class
LUFactorTridiagonal
(
Op
):
"""Compute LU factorization of a tridiagonal matrix (lapack gttrf)"""
__props__
=
(
"overwrite_dl"
,
"overwrite_d"
,
"overwrite_du"
,
)
gufunc_signature
=
"(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
def
__init__
(
self
,
overwrite_dl
=
False
,
overwrite_d
=
False
,
overwrite_du
=
False
):
self
.
destroy_map
=
dm
=
{}
if
overwrite_dl
:
dm
[
0
]
=
[
0
]
if
overwrite_d
:
dm
[
1
]
=
[
1
]
if
overwrite_du
:
dm
[
2
]
=
[
2
]
self
.
overwrite_dl
=
overwrite_dl
self
.
overwrite_d
=
overwrite_d
self
.
overwrite_du
=
overwrite_du
super
()
.
__init__
()
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
return
type
(
self
)(
overwrite_dl
=
0
in
allowed_inplace_inputs
,
overwrite_d
=
1
in
allowed_inplace_inputs
,
overwrite_du
=
2
in
allowed_inplace_inputs
,
)
def
make_node
(
self
,
dl
,
d
,
du
):
dl
,
d
,
du
=
map
(
as_tensor
,
(
dl
,
d
,
du
))
if
not
all
(
inp
.
type
.
ndim
==
1
for
inp
in
(
dl
,
d
,
du
)):
raise
ValueError
(
"Diagonals must be vectors"
)
ndl
,
nd
,
ndu
=
(
inp
.
type
.
shape
[
-
1
]
for
inp
in
(
dl
,
d
,
du
))
match
(
ndl
,
nd
,
ndu
):
case
(
int
(),
_
,
_
):
n
=
ndl
+
1
case
(
_
,
int
(),
_
):
n
=
nd
+
1
case
(
_
,
_
,
int
()):
n
=
ndu
+
1
case
_
:
n
=
None
dummy_arrays
=
[
np
.
zeros
((),
dtype
=
inp
.
type
.
dtype
)
for
inp
in
(
dl
,
d
,
du
)]
out_dtype
=
get_lapack_funcs
(
"gttrf"
,
dummy_arrays
)
.
dtype
outputs
=
[
vector
(
shape
=
(
None
if
n
is
None
else
(
n
-
1
),),
dtype
=
out_dtype
),
vector
(
shape
=
(
n
,),
dtype
=
out_dtype
),
vector
(
shape
=
(
None
if
n
is
None
else
n
-
1
,),
dtype
=
out_dtype
),
vector
(
shape
=
(
None
if
n
is
None
else
n
-
2
,),
dtype
=
out_dtype
),
vector
(
shape
=
(
n
,),
dtype
=
np
.
int32
),
]
return
Apply
(
self
,
[
dl
,
d
,
du
],
outputs
)
def
perform
(
self
,
node
,
inputs
,
output_storage
):
gttrf
=
get_lapack_funcs
(
"gttrf"
,
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
)
dl
,
d
,
du
,
du2
,
ipiv
,
_
=
gttrf
(
*
inputs
,
overwrite_dl
=
self
.
overwrite_dl
,
overwrite_d
=
self
.
overwrite_d
,
overwrite_du
=
self
.
overwrite_du
,
)
output_storage
[
0
][
0
]
=
dl
output_storage
[
1
][
0
]
=
d
output_storage
[
2
][
0
]
=
du
output_storage
[
3
][
0
]
=
du2
output_storage
[
4
][
0
]
=
ipiv
class
SolveLUFactorTridiagonal
(
Op
):
"""Solve a system of linear equations with a tridiagonal coefficient matrix (lapack gttrs)."""
__props__
=
(
"b_ndim"
,
"overwrite_b"
,
"transposed"
)
def
__init__
(
self
,
b_ndim
:
int
,
transposed
:
bool
,
overwrite_b
=
False
):
if
b_ndim
not
in
(
1
,
2
):
raise
ValueError
(
"b_ndim must be 1 or 2"
)
if
b_ndim
==
1
:
self
.
gufunc_signature
=
"(dl),(d),(dl),(du2),(d),(d)->(d)"
else
:
self
.
gufunc_signature
=
"(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
if
overwrite_b
:
self
.
destroy_map
=
{
0
:
[
5
]}
self
.
b_ndim
=
b_ndim
self
.
transposed
=
transposed
self
.
overwrite_b
=
overwrite_b
super
()
.
__init__
()
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
# b matrix is the 5th input
if
5
in
allowed_inplace_inputs
:
props
=
self
.
_props_dict
()
# type: ignore
props
[
"overwrite_b"
]
=
True
return
type
(
self
)(
**
props
)
return
self
def
make_node
(
self
,
dl
,
d
,
du
,
du2
,
ipiv
,
b
):
dl
,
d
,
du
,
du2
,
ipiv
,
b
=
map
(
as_tensor
,
(
dl
,
d
,
du
,
du2
,
ipiv
,
b
))
if
b
.
type
.
ndim
!=
self
.
b_ndim
:
raise
ValueError
(
"Wrong number of dimensions for input b."
)
if
not
all
(
inp
.
type
.
ndim
==
1
for
inp
in
(
dl
,
d
,
du
,
du2
,
ipiv
)):
raise
ValueError
(
"Inputs must be vectors"
)
ndl
,
nd
,
ndu
,
ndu2
,
nipiv
=
(
inp
.
type
.
shape
[
-
1
]
for
inp
in
(
dl
,
d
,
du
,
du2
,
ipiv
)
)
nb
=
b
.
type
.
shape
[
0
]
match
(
ndl
,
nd
,
ndu
,
ndu2
,
nipiv
):
case
(
int
(),
_
,
_
,
_
,
_
):
n
=
ndl
+
1
case
(
_
,
int
(),
_
,
_
,
_
):
n
=
nd
case
(
_
,
_
,
int
(),
_
,
_
):
n
=
ndu
+
1
case
(
_
,
_
,
_
,
int
(),
_
):
n
=
ndu2
+
2
case
(
_
,
_
,
_
,
_
,
int
()):
n
=
nipiv
case
_
:
n
=
nb
dummy_arrays
=
[
np
.
zeros
((),
dtype
=
inp
.
type
.
dtype
)
for
inp
in
(
dl
,
d
,
du
,
du2
,
ipiv
)
]
# Seems to always be float64?
out_dtype
=
get_lapack_funcs
(
"gttrs"
,
dummy_arrays
)
.
dtype
if
self
.
b_ndim
==
1
:
output_shape
=
(
n
,)
else
:
output_shape
=
(
n
,
b
.
type
.
shape
[
-
1
])
outputs
=
[
tensor
(
shape
=
output_shape
,
dtype
=
out_dtype
)]
return
Apply
(
self
,
[
dl
,
d
,
du
,
du2
,
ipiv
,
b
],
outputs
)
def
perform
(
self
,
node
,
inputs
,
output_storage
):
gttrs
=
get_lapack_funcs
(
"gttrs"
,
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
)
x
,
_
=
gttrs
(
*
inputs
,
overwrite_b
=
self
.
overwrite_b
,
trans
=
"N"
if
not
self
.
transposed
else
"T"
,
)
output_storage
[
0
][
0
]
=
x
def
tridiagonal_lu_factor
(
a
:
"TensorLike"
,
)
->
tuple
[
TensorVariable
,
TensorVariable
,
TensorVariable
,
TensorVariable
,
TensorVariable
]:
"""Return the decomposition of A implied by a solve tridiagonal (LAPACK's gttrf)
Parameters
----------
a
The input matrix.
Returns
-------
dl, d, du, du2, ipiv
The LU factorization of A.
"""
dl
,
d
,
du
=
(
diagonal
(
a
,
offset
=
o
,
axis1
=-
2
,
axis2
=-
1
)
for
o
in
(
-
1
,
0
,
1
))
dl
,
d
,
du
,
du2
,
ipiv
=
typing
.
cast
(
list
[
TensorVariable
],
Blockwise
(
LUFactorTridiagonal
())(
dl
,
d
,
du
)
)
return
dl
,
d
,
du
,
du2
,
ipiv
def
tridiagonal_lu_solve
(
a_diagonals
:
tuple
[
"TensorLike"
,
"TensorLike"
,
"TensorLike"
,
"TensorLike"
,
"TensorLike"
],
b
:
"TensorLike"
,
*
,
b_ndim
:
int
,
transposed
:
bool
=
False
,
)
->
TensorVariable
:
"""Solve a tridiagonal system of equations using LU factorized inputs (LAPACK's gttrs).
Parameters
----------
a_diagonals
The outputs of tridiagonal_lu_factor(A).
b
The right-hand side vector or matrix.
b_ndim
The number of dimensions of the right-hand side.
transposed
Whether to solve the transposed system.
Returns
-------
TensorVariable
The solution vector or matrix.
"""
dl
,
d
,
du
,
du2
,
ipiv
=
a_diagonals
return
typing
.
cast
(
TensorVariable
,
Blockwise
(
SolveLUFactorTridiagonal
(
b_ndim
=
b_ndim
,
transposed
=
transposed
))(
dl
,
d
,
du
,
du2
,
ipiv
,
b
),
)
tests/link/numba/linalg/__init__.py
0 → 100644
浏览文件 @
d88c7351
tests/link/numba/linalg/solve/__init__.py
0 → 100644
浏览文件 @
d88c7351
tests/link/numba/linalg/solve/test_tridiagonal.py
0 → 100644
浏览文件 @
d88c7351
import
numpy
as
np
import
pytest
import
scipy
from
pytensor
import
In
from
pytensor
import
tensor
as
pt
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
LUFactorTridiagonal
,
SolveLUFactorTridiagonal
,
)
from
pytensor.tensor.blockwise
import
Blockwise
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
numba_inplace_mode
@pytest.mark.parametrize
(
"inplace"
,
[
False
,
True
],
ids
=
lambda
x
:
f
"inplace={x}"
)
def
test_tridiagonal_lu_factor
(
inplace
):
dl
=
pt
.
vector
(
"dl"
,
shape
=
(
4
,))
d
=
pt
.
vector
(
"d"
,
shape
=
(
5
,))
du
=
pt
.
vector
(
"du"
,
shape
=
(
4
,))
lu_factor_outs
=
Blockwise
(
LUFactorTridiagonal
())(
dl
,
d
,
du
)
rng
=
np
.
random
.
default_rng
(
734
)
dl_test
=
rng
.
random
(
dl
.
type
.
shape
)
d_test
=
rng
.
random
(
d
.
type
.
shape
)
du_test
=
rng
.
random
(
du
.
type
.
shape
)
f
,
results
=
compare_numba_and_py
(
[
In
(
dl
,
mutable
=
inplace
),
In
(
d
,
mutable
=
inplace
),
In
(
du
,
mutable
=
inplace
),
],
lu_factor_outs
,
test_inputs
=
[
dl_test
,
d_test
,
du_test
],
inplace
=
True
,
numba_mode
=
numba_inplace_mode
,
eval_obj_mode
=
False
,
)
# Test with contiguous inputs
dl_test_contig
=
dl_test
.
copy
()
d_test_contig
=
d_test
.
copy
()
du_test_contig
=
du_test
.
copy
()
results_contig
=
f
(
dl_test_contig
,
d_test_contig
,
du_test_contig
)
for
res
,
res_contig
in
zip
(
results
,
results_contig
):
np
.
testing
.
assert_allclose
(
res
,
res_contig
)
assert
(
dl_test_contig
==
dl_test
)
.
all
()
==
(
not
inplace
)
assert
(
d_test_contig
==
d_test
)
.
all
()
==
(
not
inplace
)
assert
(
du_test_contig
==
du_test
)
.
all
()
==
(
not
inplace
)
# Test with non-contiguous inputs
dl_test_not_contig
=
np
.
repeat
(
dl_test
,
2
)[::
2
]
d_test_not_contig
=
np
.
repeat
(
d_test
,
2
)[::
2
]
du_test_not_contig
=
np
.
repeat
(
du_test
,
2
)[::
2
]
results_not_contig
=
f
(
dl_test_not_contig
,
d_test_not_contig
,
du_test_not_contig
)
for
res
,
res_not_contig
in
zip
(
results
,
results_not_contig
):
np
.
testing
.
assert_allclose
(
res
,
res_not_contig
)
# Non-contiguous inputs have to be copied so are not modified in place
assert
(
dl_test_not_contig
==
dl_test
)
.
all
()
assert
(
d_test_not_contig
==
d_test
)
.
all
()
assert
(
du_test_not_contig
==
du_test
)
.
all
()
@pytest.mark.parametrize
(
"transposed"
,
[
False
,
True
],
ids
=
lambda
x
:
f
"transposed={x}"
)
@pytest.mark.parametrize
(
"inplace"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"inplace={x}"
)
@pytest.mark.parametrize
(
"b_ndim"
,
[
1
,
2
],
ids
=
lambda
x
:
f
"b_ndim={x}"
)
def
test_tridiagonal_lu_solve
(
b_ndim
,
transposed
,
inplace
):
scipy_gttrf
=
scipy
.
linalg
.
get_lapack_funcs
(
"gttrf"
)
dl
=
pt
.
tensor
(
"dl"
,
shape
=
(
9
,))
d
=
pt
.
tensor
(
"d"
,
shape
=
(
10
,))
du
=
pt
.
tensor
(
"du"
,
shape
=
(
9
,))
du2
=
pt
.
tensor
(
"du2"
,
shape
=
(
8
,))
ipiv
=
pt
.
tensor
(
"ipiv"
,
shape
=
(
10
,),
dtype
=
"int32"
)
diagonals
=
[
dl
,
d
,
du
,
du2
,
ipiv
]
b
=
pt
.
tensor
(
"b"
,
shape
=
(
10
,
25
)[:
b_ndim
])
x
=
Blockwise
(
SolveLUFactorTridiagonal
(
b_ndim
=
b
.
type
.
ndim
,
transposed
=
transposed
))(
*
diagonals
,
b
)
rng
=
np
.
random
.
default_rng
(
787
)
A_test
=
rng
.
random
((
d
.
type
.
shape
[
0
],
d
.
type
.
shape
[
0
]))
*
diagonals_test
,
_
=
scipy_gttrf
(
*
(
np
.
diagonal
(
A_test
,
offset
=
o
)
for
o
in
(
-
1
,
0
,
1
))
)
b_test
=
rng
.
random
(
b
.
type
.
shape
)
f
,
res
=
compare_numba_and_py
(
[
*
diagonals
,
In
(
b
,
mutable
=
inplace
),
],
x
,
test_inputs
=
[
*
diagonals_test
,
b_test
],
inplace
=
True
,
numba_mode
=
numba_inplace_mode
,
eval_obj_mode
=
False
,
)
# Test with contiguous_inputs
diagonals_test_contig
=
[
d_test
.
copy
()
for
d_test
in
diagonals_test
]
b_test_contig
=
b_test
.
copy
(
order
=
"F"
)
res_contig
=
f
(
*
diagonals_test_contig
,
b_test_contig
)
assert
(
res_contig
==
res
)
.
all
()
assert
(
b_test
==
b_test_contig
)
.
all
()
==
(
not
inplace
)
# Test with non-contiguous inputs
diagonals_test_non_contig
=
[
np
.
repeat
(
d_test
,
2
)[::
2
]
for
d_test
in
diagonals_test
]
b_test_non_contig
=
np
.
repeat
(
b_test
,
2
,
axis
=
0
)[::
2
]
res_non_contig
=
f
(
*
diagonals_test_non_contig
,
b_test_non_contig
)
assert
(
res_non_contig
==
res
)
.
all
()
# b must be copied when not contiguous so it can't be inplaced
assert
(
b_test
==
b_test_non_contig
)
.
all
()
tests/tensor/linalg/test_rewriting.py
浏览文件 @
d88c7351
...
...
@@ -9,6 +9,10 @@ from pytensor.tensor._linalg.solve.rewriting import (
reuse_lu_decomposition_multiple_solves
,
scan_split_non_sequence_lu_decomposition_solve
,
)
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
LUFactorTridiagonal
,
SolveLUFactorTridiagonal
,
)
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.linalg
import
solve
from
pytensor.tensor.slinalg
import
LUFactor
,
Solve
,
SolveTriangular
...
...
@@ -28,9 +32,10 @@ def count_vanilla_solve_nodes(nodes) -> int:
def
count_lu_decom_nodes
(
nodes
)
->
int
:
return
sum
(
(
isinstance
(
node
.
op
,
LUFactor
)
isinstance
(
node
.
op
,
LUFactor
|
LUFactorTridiagonal
)
or
(
isinstance
(
node
.
op
,
Blockwise
)
and
isinstance
(
node
.
op
.
core_op
,
LUFactor
)
isinstance
(
node
.
op
,
Blockwise
)
and
isinstance
(
node
.
op
.
core_op
,
LUFactor
|
LUFactorTridiagonal
)
)
)
for
node
in
nodes
...
...
@@ -40,27 +45,38 @@ def count_lu_decom_nodes(nodes) -> int:
def
count_lu_solve_nodes
(
nodes
)
->
int
:
count
=
sum
(
(
# LUFactor uses 2 SolveTriangular nodes, so we count each as 0.5
0.5
*
(
isinstance
(
node
.
op
,
SolveTriangular
)
or
(
isinstance
(
node
.
op
,
Blockwise
)
and
isinstance
(
node
.
op
.
core_op
,
SolveTriangular
)
)
)
or
(
isinstance
(
node
.
op
,
SolveLUFactorTridiagonal
)
or
(
isinstance
(
node
.
op
,
Blockwise
)
and
isinstance
(
node
.
op
.
core_op
,
SolveLUFactorTridiagonal
)
)
)
)
for
node
in
nodes
)
# Each LU solve uses two Triangular solves
return
count
//
2
return
int
(
count
)
@pytest.mark.parametrize
(
"transposed"
,
(
False
,
True
))
def
test_lu_decomposition_reused_forward_and_gradient
(
transposed
):
@pytest.mark.parametrize
(
"assume_a"
,
(
"gen"
,
"tridiagonal"
))
def
test_lu_decomposition_reused_forward_and_gradient
(
assume_a
,
transposed
):
rewrite_name
=
reuse_lu_decomposition_multiple_solves
.
__name__
mode
=
get_default_mode
()
A
=
tensor
(
"A"
,
shape
=
(
2
,
2
))
b
=
tensor
(
"b"
,
shape
=
(
2
,
3
))
A
=
tensor
(
"A"
,
shape
=
(
3
,
3
))
b
=
tensor
(
"b"
,
shape
=
(
3
,
4
))
x
=
solve
(
A
,
b
,
assume_a
=
"gen"
,
transposed
=
transposed
)
x
=
solve
(
A
,
b
,
assume_a
=
assume_a
,
transposed
=
transposed
)
grad_x_wrt_A
=
grad
(
x
.
sum
(),
A
)
fn_no_opt
=
function
([
A
,
b
],
[
x
,
grad_x_wrt_A
],
mode
=
mode
.
excluding
(
rewrite_name
))
no_opt_nodes
=
fn_no_opt
.
maker
.
fgraph
.
apply_nodes
...
...
@@ -80,20 +96,21 @@ def test_lu_decomposition_reused_forward_and_gradient(transposed):
b_test
=
rng
.
random
(
b
.
type
.
shape
,
dtype
=
b
.
type
.
dtype
)
resx0
,
resg0
=
fn_no_opt
(
A_test
,
b_test
)
resx1
,
resg1
=
fn_opt
(
A_test
,
b_test
)
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-
6
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-
4
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
np
.
testing
.
assert_allclose
(
resg0
,
resg1
,
rtol
=
rtol
)
@pytest.mark.parametrize
(
"transposed"
,
(
False
,
True
))
def
test_lu_decomposition_reused_blockwise
(
transposed
):
@pytest.mark.parametrize
(
"assume_a"
,
(
"gen"
,
"tridiagonal"
))
def
test_lu_decomposition_reused_blockwise
(
assume_a
,
transposed
):
rewrite_name
=
reuse_lu_decomposition_multiple_solves
.
__name__
mode
=
get_default_mode
()
A
=
tensor
(
"A"
,
shape
=
(
2
,
2
))
b
=
tensor
(
"b"
,
shape
=
(
2
,
2
,
3
))
A
=
tensor
(
"A"
,
shape
=
(
3
,
3
))
b
=
tensor
(
"b"
,
shape
=
(
2
,
3
,
4
))
x
=
solve
(
A
,
b
,
transposed
=
transposed
)
x
=
solve
(
A
,
b
,
assume_a
=
assume_a
,
transposed
=
transposed
)
fn_no_opt
=
function
([
A
,
b
],
[
x
],
mode
=
mode
.
excluding
(
rewrite_name
))
no_opt_nodes
=
fn_no_opt
.
maker
.
fgraph
.
apply_nodes
assert
count_vanilla_solve_nodes
(
no_opt_nodes
)
==
1
...
...
@@ -112,19 +129,21 @@ def test_lu_decomposition_reused_blockwise(transposed):
b_test
=
rng
.
random
(
b
.
type
.
shape
,
dtype
=
b
.
type
.
dtype
)
resx0
=
fn_no_opt
(
A_test
,
b_test
)
resx1
=
fn_opt
(
A_test
,
b_test
)
np
.
testing
.
assert_allclose
(
resx0
,
resx1
)
rtol
=
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-4
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
@pytest.mark.parametrize
(
"transposed"
,
(
False
,
True
))
def
test_lu_decomposition_reused_scan
(
transposed
):
@pytest.mark.parametrize
(
"assume_a"
,
(
"gen"
,
"tridiagonal"
))
def
test_lu_decomposition_reused_scan
(
assume_a
,
transposed
):
rewrite_name
=
scan_split_non_sequence_lu_decomposition_solve
.
__name__
mode
=
get_default_mode
()
A
=
tensor
(
"A"
,
shape
=
(
2
,
2
))
x0
=
tensor
(
"b"
,
shape
=
(
2
,
3
))
A
=
tensor
(
"A"
,
shape
=
(
3
,
3
))
x0
=
tensor
(
"b"
,
shape
=
(
3
,
4
))
xs
,
_
=
scan
(
lambda
xtm1
,
A
:
solve
(
A
,
xtm1
,
assume_a
=
"general"
,
transposed
=
transposed
),
lambda
xtm1
,
A
:
solve
(
A
,
xtm1
,
assume_a
=
assume_a
,
transposed
=
transposed
),
outputs_info
=
[
x0
],
non_sequences
=
[
A
],
n_steps
=
10
,
...
...
@@ -159,7 +178,7 @@ def test_lu_decomposition_reused_scan(transposed):
x0_test
=
rng
.
random
(
x0
.
type
.
shape
,
dtype
=
x0
.
type
.
dtype
)
resx0
=
fn_no_opt
(
A_test
,
x0_test
)
resx1
=
fn_opt
(
A_test
,
x0_test
)
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-
6
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-
4
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论