Unverified commit 920b409b, authored by Pham Nguyen Hung, committed by GitHub

Add rewrite to merge multiple SVD Ops with different settings (#769)

Parent commit: a8d76381
......@@ -4,13 +4,17 @@ from typing import cast
from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
node_rewriter,
)
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
......@@ -18,6 +22,7 @@ from pytensor.tensor.nlinalg import (
inv,
kron,
pinv,
svd,
)
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
......@@ -377,3 +382,59 @@ def local_lift_through_linalg(
return [block_diag(*inner_matrices)]
else:
raise NotImplementedError # pragma: no cover
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def svd_uv_merge(fgraph, node):
    """Merge multiple `SVD` `Op`s applied to the same input.

    If the graph holds more than one `SVD` of the same matrix and at least one
    of them has ``compute_uv=True``, the singular values of the others can be
    taken from that decomposition instead of being recomputed, letting
    `pytensor` re-use a single SVD node.
    """
    core_op = node.op.core_op
    if not isinstance(core_op, SVD):
        return
    (matrix,) = node.inputs

    def _svd_siblings():
        # Yield every other Blockwise-SVD client of the same input matrix,
        # skipping the special "output" client marker.
        for client, _ in fgraph.clients[matrix]:
            if client == "output":
                continue
            if isinstance(client.op, Blockwise) and isinstance(
                client.op.core_op, SVD
            ):
                yield client

    if core_op.compute_uv:
        # compute_uv=True returns [u, s, v]; if u or v is actually consumed
        # there is nothing to simplify for this node.
        if fgraph.clients[node.outputs[0]] or fgraph.clients[node.outputs[2]]:
            return
        # Only s is needed. Prefer replacing it with the s of an existing
        # compute_uv=False SVD of the same matrix.
        for sibling in _svd_siblings():
            if not sibling.op.core_op.compute_uv:
                return {
                    node.outputs[1]: sibling.outputs[0],
                }
        # No reusable SVD found: introduce a fresh s-only decomposition,
        # preserving this node's full_matrices setting.
        return {
            node.outputs[1]: svd(
                matrix, full_matrices=core_op.full_matrices, compute_uv=False
            ),
        }
    else:
        # compute_uv=False returns [s]. If another SVD of the same matrix has
        # compute_uv=True and its u or v outputs are genuinely used, reuse its
        # s so only that one decomposition remains.
        for sibling in _svd_siblings():
            if sibling.op.core_op.compute_uv and (
                fgraph.clients[sibling.outputs[0]]
                or fgraph.clients[sibling.outputs[2]]
            ):
                return [sibling.outputs[1]]
......@@ -15,11 +15,13 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import (
SVD,
Det,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
matrix_inverse,
svd,
)
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import (
......@@ -390,3 +392,67 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
def test_svd_uv_merge():
    """Check that `svd_uv_merge` collapses redundant SVD nodes as expected."""
    a = matrix("a")
    s_no_uv = svd(a, full_matrices=False, compute_uv=False)
    _, s_uv, _ = svd(a, full_matrices=False, compute_uv=True)
    _, s_uv_full, _ = svd(a, full_matrices=True, compute_uv=True)
    u_kept, s_unused, v_unused = svd(a, full_matrices=True, compute_uv=True)

    def compiled_svd_ops(outputs):
        # Compile a function for `outputs` and collect the SVD core ops
        # remaining in its optimized graph.
        fn = pytensor.function([a], outputs)
        return [
            apply_node.op
            for apply_node in fn.maker.fgraph.apply_nodes
            if isinstance(apply_node.op, SVD)
        ]

    # `grad` will introduces an SVD Op with compute_uv=True
    # full_matrices = True is not supported for grad of svd
    gs = pt.grad(pt.sum(s_no_uv), a)

    # 1. compute_uv=False needs rewriting with compute_uv=True
    ops = compiled_svd_ops(gs)
    assert len(ops) == 1
    assert ops[0].compute_uv

    # 2. compute_uv=True needs rewriting with compute=False, reuse node
    ops = compiled_svd_ops([s_no_uv, s_uv])
    assert len(ops) == 1
    assert not ops[0].compute_uv

    # 3. compute_uv=True needs rewriting with compute=False, create new node
    # full_matrices needs to retain the value
    ops = compiled_svd_ops([s_uv])
    assert len(ops) == 1
    assert not ops[0].compute_uv

    # Case 2 of 3. for a different full_matrices
    ops = compiled_svd_ops([s_uv_full])
    assert len(ops) == 1
    assert not ops[0].compute_uv
    assert ops[0].full_matrices

    # 4. No rewrite should happen
    ops = compiled_svd_ops([u_kept])
    assert len(ops) == 1
    assert ops[0].full_matrices
    assert ops[0].compute_uv
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment