Unverified commit 3eea7d0e authored by Tanish, committed by GitHub

Added rewrites involving block diagonal matrices (#967)

* Added rewrite for diag(block_diag)
* Added rewrite for determinant of block_diag
* Added rewrite for slogdet; added docstrings for all 3 rewrites
* Fixed typecasting for tests

Parent commit: 2086aeb8
......@@ -12,8 +12,11 @@ from pytensor.graph.rewriting.basic import (
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import (
AllocDiag,
ExtractDiag,
Eye,
TensorVariable,
concatenate,
diag,
diagonal,
)
from pytensor.tensor.blas import Dot22
......@@ -29,6 +32,7 @@ from pytensor.tensor.nlinalg import (
inv,
kron,
pinv,
slogdet,
svd,
)
from pytensor.tensor.rewriting.basic import (
......@@ -701,3 +705,116 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
non_eye_input = pt.shape_padaxis(non_eye_diag, -2)
return [eye_input / non_eye_input]
@register_canonicalize
@register_stabilize
@node_rewriter([ExtractDiag])
def rewrite_diag_blockdiag(fgraph, node):
    """
    Simplify extracting the main diagonal of a block-diagonal matrix by
    concatenating the diagonals of the individual sub-matrices.

    diag(block_diag(a, b, c, ...)) = concat(diag(a), diag(b), diag(c), ...)

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # The identity only holds for the main diagonal with default axes; a
    # shifted diagonal (e.g. pt.diagonal(x, offset=1)) crosses block
    # boundaries and must not be rewritten.
    if not (node.op.offset == 0 and node.op.axis1 == 0 and node.op.axis2 == 1):
        return None

    # Check for an inner block_diag operation
    potential_block_diag = node.inputs[0].owner
    if not (
        potential_block_diag
        and isinstance(potential_block_diag.op, Blockwise)
        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
    ):
        return None

    # Concatenate the diagonals of the composing sub-matrices
    return [concatenate([diag(block) for block in potential_block_diag.inputs])]
@register_canonicalize
@register_stabilize
@node_rewriter([det])
def rewrite_det_blockdiag(fgraph, node):
    """
    Simplify the determinant of a block-diagonal matrix to the product of the
    determinants of its diagonal blocks.

    det(block_diag(a, b, c, ...)) = det(a) * det(b) * det(c) * ...

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # The rewrite only applies when the determinant's input is a block_diag
    block_diag_node = node.inputs[0].owner
    if block_diag_node is None:
        return None
    outer_op = block_diag_node.op
    if not (
        isinstance(outer_op, Blockwise) and isinstance(outer_op.core_op, BlockDiagonal)
    ):
        return None

    # Multiply together the determinants of the composing sub-matrices
    return [prod([det(block) for block in block_diag_node.inputs])]
@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_blockdiag(fgraph, node):
    """
    Simplify the slogdet of a block-diagonal matrix by computing slogdet on
    each diagonal block: the overall sign is the product of the block signs
    and the overall log-determinant is the sum of the block log-determinants.

    slogdet(block_diag(a, b, c, ...)) =
        prod(sign(a), sign(b), ...), sum(logdet(a), logdet(b), ...)

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # The rewrite only applies when slogdet's input is a block_diag
    block_diag_node = node.inputs[0].owner
    if block_diag_node is None:
        return None
    outer_op = block_diag_node.op
    if not (
        isinstance(outer_op, Blockwise) and isinstance(outer_op.core_op, BlockDiagonal)
    ):
        return None

    # Accumulate per-block signs and log-determinants
    signs = []
    logdets = []
    for block in block_diag_node.inputs:
        block_sign, block_logdet = slogdet(block)
        signs.append(block_sign)
        logdets.append(block_logdet)

    return [prod(signs), sum(logdets)]
......@@ -662,3 +662,92 @@ def test_inv_diag_from_diag(inv_op):
atol=ATOL,
rtol=RTOL,
)
def test_diag_blockdiag_rewrite():
    """diag(block_diag(...)) is rewritten away and matches the numpy result."""
    n_matrices = 10
    matrix_size = (5, 5)
    tol = 1e-3 if config.floatX == "float32" else 1e-8

    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
    blocks = [sub_matrices[i] for i in range(n_matrices)]
    diag_output = pt.diag(pt.linalg.block_diag(*blocks))
    f_rewritten = function([sub_matrices], diag_output, mode="FAST_RUN")

    # Rewrite Test: the BlockDiagonal op must be gone from the compiled graph
    assert not any(
        isinstance(node.op, BlockDiagonal)
        for node in f_rewritten.maker.fgraph.apply_nodes
    )

    # Value Test: compare against the scipy/numpy reference computation
    test_values = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
    expected = np.diag(scipy.linalg.block_diag(*test_values))
    assert_allclose(expected, f_rewritten(test_values), atol=tol, rtol=tol)
def test_det_blockdiag_rewrite():
    """det(block_diag(...)) is rewritten away and matches the numpy result."""
    n_matrices = 100
    matrix_size = (5, 5)
    tol = 1e-3 if config.floatX == "float32" else 1e-8

    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
    blocks = [sub_matrices[i] for i in range(n_matrices)]
    det_output = pt.linalg.det(pt.linalg.block_diag(*blocks))
    f_rewritten = function([sub_matrices], det_output, mode="FAST_RUN")

    # Rewrite Test: the BlockDiagonal op must be gone from the compiled graph
    assert not any(
        isinstance(node.op, BlockDiagonal)
        for node in f_rewritten.maker.fgraph.apply_nodes
    )

    # Value Test: compare against the scipy/numpy reference computation
    test_values = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
    expected = np.linalg.det(scipy.linalg.block_diag(*test_values))
    assert_allclose(expected, f_rewritten(test_values), atol=tol, rtol=tol)
def test_slogdet_blockdiag_rewrite():
    """slogdet(block_diag(...)) is rewritten away and matches the numpy result."""
    n_matrices = 100
    matrix_size = (5, 5)
    tol = 1e-3 if config.floatX == "float32" else 1e-8

    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
    blocks = [sub_matrices[i] for i in range(n_matrices)]
    sign_output, logdet_output = pt.linalg.slogdet(pt.linalg.block_diag(*blocks))
    f_rewritten = function(
        [sub_matrices], [sign_output, logdet_output], mode="FAST_RUN"
    )

    # Rewrite Test: the BlockDiagonal op must be gone from the compiled graph
    assert not any(
        isinstance(node.op, BlockDiagonal)
        for node in f_rewritten.maker.fgraph.apply_nodes
    )

    # Value Test: compare against the scipy/numpy reference computation
    test_values = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
    expected_sign, expected_logdet = np.linalg.slogdet(
        scipy.linalg.block_diag(*test_values)
    )
    actual_sign, actual_logdet = f_rewritten(test_values)
    assert_allclose(expected_sign, actual_sign, atol=tol, rtol=tol)
    assert_allclose(expected_logdet, actual_logdet, atol=tol, rtol=tol)
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment