Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9df35e8d
提交
9df35e8d
authored
2月 02, 2024
作者:
jessegrabowski
提交者:
Ricardo Vieira
4月 28, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add rewrite to lift linear algebra through certain linalg ops
上级
d34760d7
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
149 行增加
和
4 行删除
+149
-4
builders.py
pytensor/compile/builders.py
+1
-1
nlinalg.py
pytensor/tensor/nlinalg.py
+9
-1
linalg.py
pytensor/tensor/rewriting/linalg.py
+73
-1
test_linalg.py
tests/tensor/rewriting/test_linalg.py
+58
-1
test_nlinalg.py
tests/tensor/test_nlinalg.py
+8
-0
没有找到文件。
pytensor/compile/builders.py
浏览文件 @
9df35e8d
...
...
@@ -7,7 +7,7 @@ from functools import partial
from
typing
import
cast
import
pytensor.tensor
as
pt
from
pytensor
import
function
from
pytensor
.compile.function
import
function
from
pytensor.compile.function.pfunc
import
rebuild_collect_shared
from
pytensor.compile.mode
import
optdb
from
pytensor.compile.sharedvalue
import
SharedVariable
...
...
pytensor/tensor/nlinalg.py
浏览文件 @
9df35e8d
...
...
@@ -7,6 +7,7 @@ import numpy as np
from
numpy.core.numeric
import
normalize_axis_tuple
# type: ignore
from
pytensor
import
scalar
as
ps
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.gradient
import
DisconnectedType
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
...
...
@@ -1011,6 +1012,12 @@ def tensorsolve(a, b, axes=None):
return
TensorSolve
(
axes
)(
a
,
b
)
class
KroneckerProduct
(
OpFromGraph
):
"""
Wrapper Op for Kronecker graphs
"""
def
kron
(
a
,
b
):
"""Kronecker product.
...
...
@@ -1042,7 +1049,8 @@ def kron(a, b):
out_shape
=
tuple
(
a
.
shape
*
b
.
shape
)
output_out_of_shape
=
a_reshaped
*
b_reshaped
output_reshaped
=
output_out_of_shape
.
reshape
(
out_shape
)
return
output_reshaped
return
KroneckerProduct
(
inputs
=
[
a
,
b
],
outputs
=
[
output_reshaped
])(
a
,
b
)
__all__
=
[
...
...
pytensor/tensor/rewriting/linalg.py
浏览文件 @
9df35e8d
import
logging
from
collections.abc
import
Callable
from
typing
import
cast
from
pytensor
import
Variable
from
pytensor.graph
import
Apply
,
FunctionGraph
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
node_rewriter
from
pytensor.tensor.basic
import
TensorVariable
,
diagonal
from
pytensor.tensor.blas
import
Dot22
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
Dot
,
Prod
,
_matrix_matrix_matmul
,
log
,
prod
from
pytensor.tensor.nlinalg
import
MatrixInverse
,
det
from
pytensor.tensor.nlinalg
import
(
KroneckerProduct
,
MatrixInverse
,
MatrixPinv
,
det
,
inv
,
kron
,
pinv
,
)
from
pytensor.tensor.rewriting.basic
import
(
register_canonicalize
,
register_specialize
,
register_stabilize
,
)
from
pytensor.tensor.slinalg
import
(
BlockDiagonal
,
Cholesky
,
Solve
,
SolveBase
,
block_diag
,
cholesky
,
solve
,
solve_triangular
,
...
...
@@ -305,3 +318,62 @@ def local_log_prod_sqr(fgraph, node):
# TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication.
@register_specialize
@node_rewriter([Blockwise])
def local_lift_through_linalg(
    fgraph: FunctionGraph, node: Apply
) -> list[Variable] | None:
    """Push an expensive matrix op (inv/cholesky/pinv) through a matrix-joining op.

    When an expensive decomposition or inversion is applied to the output of
    ``KroneckerProduct`` or ``BlockDiagonal``, apply it to each component matrix
    instead and re-join the results. E.g. ``inv(kron(A, B))`` becomes
    ``kron(inv(A), inv(B))``, dropping the cost from O((n*m)^3) to
    O(n^3 + m^3) for component sizes n and m.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # TODO: Simplify this if we end up Blockwising KroneckerProduct
    outer_op = node.op
    if not isinstance(outer_op.core_op, (MatrixInverse, Cholesky, MatrixPinv)):
        return None

    combined = node.inputs[0]
    if not combined.owner:
        return None

    joiner = combined.owner.op
    is_block_diag = isinstance(joiner, Blockwise) and isinstance(
        joiner.core_op, BlockDiagonal
    )
    is_kron = isinstance(joiner, KroneckerProduct)
    if not (is_block_diag or is_kron):
        return None

    # Map the Blockwise core op to the user-facing constructor to re-apply
    # on each component matrix.
    for op_type, candidate in (
        (MatrixInverse, inv),
        (Cholesky, cholesky),
        (MatrixPinv, pinv),
    ):
        if isinstance(outer_op.core_op, op_type):
            outer_f = cast(Callable, candidate)
            break
    else:
        raise NotImplementedError  # pragma: no cover

    lifted = [
        cast(TensorVariable, outer_f(component))
        for component in combined.owner.inputs
    ]

    if is_kron:
        return [kron(*lifted)]
    if is_block_diag:
        return [block_diag(*lifted)]
    raise NotImplementedError  # pragma: no cover
tests/tensor/rewriting/test_linalg.py
浏览文件 @
9df35e8d
...
...
@@ -14,9 +14,16 @@ from pytensor.tensor import swapaxes
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
_allclose
,
dot
,
matmul
from
pytensor.tensor.nlinalg
import
Det
,
MatrixInverse
,
matrix_inverse
from
pytensor.tensor.nlinalg
import
(
Det
,
KroneckerProduct
,
MatrixInverse
,
MatrixPinv
,
matrix_inverse
,
)
from
pytensor.tensor.rewriting.linalg
import
inv_as_solve
from
pytensor.tensor.slinalg
import
(
BlockDiagonal
,
Cholesky
,
Solve
,
SolveBase
,
...
...
@@ -333,3 +340,53 @@ class TestBatchedVectorBSolveToMatrixBSolve:
ref_fn
(
test_a
,
test_b
),
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-5
,
)
@pytest.mark.parametrize(
    "constructor", [pt.dmatrix, pt.tensor3], ids=["not_batched", "batched"]
)
@pytest.mark.parametrize(
    "f_op, f",
    [
        (MatrixInverse, pt.linalg.inv),
        (Cholesky, pt.linalg.cholesky),
        (MatrixPinv, pt.linalg.pinv),
    ],
    ids=["inv", "cholesky", "pinv"],
)
@pytest.mark.parametrize(
    "g_op, g",
    [(BlockDiagonal, pt.linalg.block_diag), (KroneckerProduct, pt.linalg.kron)],
    ids=["block_diag", "kron"],
)
def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
    """Check that f(g(A, B)) is rewritten into g(f(A), f(B)) and stays numerically equal."""
    if pytensor.config.floatX.endswith("32"):
        pytest.skip("Test is flaky at half precision")

    A, B = (constructor(name) for name in "ab")
    X = f(g(A, B))

    # Compile once with the rewrite enabled and once with it disabled.
    with_rewrite = pytensor.function(
        [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
    )
    without_rewrite = pytensor.function(
        [A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
    )

    def count_core_ops(fn, op_type):
        # Blockwise nodes expose the wrapped op via `core_op`; plain nodes don't.
        return sum(
            isinstance(getattr(apply.op, "core_op", apply.op), op_type)
            for apply in fn.maker.fgraph.apply_nodes
        )

    # After lifting, the expensive op appears once per component matrix,
    # and the joining op appears exactly once at the output.
    assert count_core_ops(with_rewrite, f_op) == 2
    assert count_core_ops(with_rewrite, g_op) == 1

    # Symmetric positive (semi-)definite inputs so cholesky is well defined.
    rng_vals = [
        np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
    ]
    test_vals = [x @ np.swapaxes(x, -1, -2) for x in rng_vals]

    np.testing.assert_allclose(
        with_rewrite(*test_vals), without_rewrite(*test_vals), atol=1e-8
    )
tests/tensor/test_nlinalg.py
浏览文件 @
9df35e8d
...
...
@@ -590,6 +590,14 @@ class TestKron(utt.InferShapeTester):
self
.
op
=
kron
super
()
.
setup_method
()
def test_vec_vec_kron_raises(self):
    """kron of two 1-D inputs (total rank 2) must be rejected with TypeError."""
    lhs = vector()
    rhs = vector()
    with pytest.raises(
        TypeError, match="kron: inputs dimensions must sum to 3 or more"
    ):
        kron(lhs, rhs)
@pytest.mark.parametrize
(
"shp0"
,
[(
2
,),
(
2
,
3
),
(
2
,
3
,
4
),
(
2
,
3
,
4
,
5
)])
@pytest.mark.parametrize
(
"shp1"
,
[(
6
,),
(
6
,
7
),
(
6
,
7
,
8
),
(
6
,
7
,
8
,
9
)])
def
test_perform
(
self
,
shp0
,
shp1
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论