Unverified 提交 920b409b authored 作者: Pham Nguyen Hung's avatar Pham Nguyen Hung 提交者: GitHub

Add rewrite to merge multiple SVD Ops with different settings (#769)

上级 a8d76381
...@@ -4,13 +4,17 @@ from typing import cast ...@@ -4,13 +4,17 @@ from typing import cast
from pytensor import Variable from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter from pytensor.graph.rewriting.basic import (
copy_stack_trace,
node_rewriter,
)
from pytensor.tensor.basic import TensorVariable, diagonal from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import ( from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct, KroneckerProduct,
MatrixInverse, MatrixInverse,
MatrixPinv, MatrixPinv,
...@@ -18,6 +22,7 @@ from pytensor.tensor.nlinalg import ( ...@@ -18,6 +22,7 @@ from pytensor.tensor.nlinalg import (
inv, inv,
kron, kron,
pinv, pinv,
svd,
) )
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
register_canonicalize, register_canonicalize,
...@@ -377,3 +382,59 @@ def local_lift_through_linalg( ...@@ -377,3 +382,59 @@ def local_lift_through_linalg(
return [block_diag(*inner_matrices)] return [block_diag(*inner_matrices)]
else: else:
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def svd_uv_merge(fgraph, node):
"""If we have more than one `SVD` `Op`s and at least one has keyword argument
`compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere
and allow `pytensor` to re-use the decomposition outputs instead of recomputing.
"""
if not isinstance(node.op.core_op, SVD):
return
(x,) = node.inputs
if node.op.core_op.compute_uv:
# compute_uv=True returns [u, s, v].
# if at least u or v is used, no need to rewrite this node.
if (
len(fgraph.clients[node.outputs[0]]) > 0
or len(fgraph.clients[node.outputs[2]]) > 0
):
return
# Else, has to replace the s of this node with s of an SVD Op that compute_uv=False.
# First, iterate to see if there is an SVD Op that can be reused.
for cl, _ in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
if not cl.op.core_op.compute_uv:
return {
node.outputs[1]: cl.outputs[0],
}
# If no SVD reusable, return a new one.
return {
node.outputs[1]: svd(
x, full_matrices=node.op.core_op.full_matrices, compute_uv=False
),
}
else:
# compute_uv=False returns [s].
# We want rewrite if there is another one with compute_uv=True.
# For this case, just reuse the `s` from the one with compute_uv=True.
for cl, _ in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
if cl.op.core_op.compute_uv and (
len(fgraph.clients[cl.outputs[0]]) > 0
or len(fgraph.clients[cl.outputs[2]]) > 0
):
return [cl.outputs[1]]
...@@ -15,11 +15,13 @@ from pytensor.tensor.blockwise import Blockwise ...@@ -15,11 +15,13 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import ( from pytensor.tensor.nlinalg import (
SVD,
Det, Det,
KroneckerProduct, KroneckerProduct,
MatrixInverse, MatrixInverse,
MatrixPinv, MatrixPinv,
matrix_inverse, matrix_inverse,
svd,
) )
from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
...@@ -390,3 +392,67 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): ...@@ -390,3 +392,67 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals] test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8) np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
def test_svd_uv_merge():
a = matrix("a")
s_1 = svd(a, full_matrices=False, compute_uv=False)
_, s_2, _ = svd(a, full_matrices=False, compute_uv=True)
_, s_3, _ = svd(a, full_matrices=True, compute_uv=True)
u_4, s_4, v_4 = svd(a, full_matrices=True, compute_uv=True)
# `grad` will introduces an SVD Op with compute_uv=True
# full_matrices = True is not supported for grad of svd
gs = pt.grad(pt.sum(s_1), a)
# 1. compute_uv=False needs rewriting with compute_uv=True
f_1 = pytensor.function([a], gs)
nodes = f_1.maker.fgraph.apply_nodes
svd_counter = 0
for node in nodes:
if isinstance(node.op, SVD):
assert node.op.compute_uv
svd_counter += 1
assert svd_counter == 1
# 2. compute_uv=True needs rewriting with compute=False, reuse node
f_2 = pytensor.function([a], [s_1, s_2])
nodes = f_2.maker.fgraph.apply_nodes
svd_counter = 0
for node in nodes:
if isinstance(node.op, SVD):
assert not node.op.compute_uv
svd_counter += 1
assert svd_counter == 1
# 3. compute_uv=True needs rewriting with compute=False, create new node
# full_matrices needs to retain the value
f_3 = pytensor.function([a], [s_2])
nodes = f_3.maker.fgraph.apply_nodes
svd_counter = 0
for node in nodes:
if isinstance(node.op, SVD):
assert not node.op.compute_uv
svd_counter += 1
assert svd_counter == 1
# Case 2 of 3. for a different full_matrices
f_4 = pytensor.function([a], [s_3])
nodes = f_4.maker.fgraph.apply_nodes
svd_counter = 0
for node in nodes:
if isinstance(node.op, SVD):
assert not node.op.compute_uv
assert node.op.full_matrices
svd_counter += 1
assert svd_counter == 1
# 4. No rewrite should happen
f_5 = pytensor.function([a], [u_4])
nodes = f_5.maker.fgraph.apply_nodes
svd_counter = 0
for node in nodes:
if isinstance(node.op, SVD):
assert node.op.full_matrices
assert node.op.compute_uv
svd_counter += 1
assert svd_counter == 1
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论