Unverified 提交 3e98b9f7 authored 作者: Tanish's avatar Tanish 提交者: GitHub

Adding rewrites involving kronecker product (#975)

* Added rewrite for diag of kronecker product * Added rewrite for slogdet; added docstrings for rewrites * fixed typo
上级 56327779
...@@ -22,7 +22,7 @@ from pytensor.tensor.basic import ( ...@@ -22,7 +22,7 @@ from pytensor.tensor.basic import (
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
from pytensor.tensor.nlinalg import ( from pytensor.tensor.nlinalg import (
SVD, SVD,
KroneckerProduct, KroneckerProduct,
...@@ -818,3 +818,72 @@ def rewrite_slogdet_blockdiag(fgraph, node): ...@@ -818,3 +818,72 @@ def rewrite_slogdet_blockdiag(fgraph, node):
) )
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
@register_canonicalize
@register_stabilize
@node_rewriter([ExtractDiag])
def rewrite_diag_kronecker(fgraph, node):
"""
This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector.
diag(kron(a,b)) -> outer(diag(a), diag(b))
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Check for inner kron operation
potential_kron = node.inputs[0].owner
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
return None
# Find the matrices
a, b = potential_kron.inputs
diag_a, diag_b = diag(a), diag(b)
outer_prod_as_vector = outer(diag_a, diag_b).flatten()
return [outer_prod_as_vector]
@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_kronecker(fgraph, node):
"""
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Check for inner kron operation
potential_kron = node.inputs[0].owner
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
return None
# Find the matrices
a, b = potential_kron.inputs
signs, logdets = zip(*[slogdet(a), slogdet(b)])
sizes = [a.shape[-1], b.shape[-1]]
prod_sizes = prod(sizes, no_zeros_in_input=True)
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
...@@ -751,3 +751,55 @@ def test_slogdet_blockdiag_rewrite(): ...@@ -751,3 +751,55 @@ def test_slogdet_blockdiag_rewrite():
atol=1e-3 if config.floatX == "float32" else 1e-8, atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8,
) )
def test_diag_kronecker_rewrite():
a, b = pt.dmatrices("a", "b")
kron_prod = pt.linalg.kron(a, b)
diag_kron_prod = pt.diag(kron_prod)
f_rewritten = function([a, b], diag_kron_prod, mode="FAST_RUN")
# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
# Value Test
a_test, b_test = np.random.rand(2, 20, 20)
kron_prod_test = np.kron(a_test, b_test)
diag_kron_prod_test = np.diag(kron_prod_test)
rewritten_val = f_rewritten(a_test, b_test)
assert_allclose(
diag_kron_prod_test,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
def test_slogdet_kronecker_rewrite():
a, b = pt.dmatrices("a", "b")
kron_prod = pt.linalg.kron(a, b)
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
# Value Test
a_test, b_test = np.random.rand(2, 20, 20)
kron_prod_test = np.kron(a_test, b_test)
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
assert_allclose(
sign_output_test,
rewritten_sign_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
assert_allclose(
logdet_output_test,
rewritten_logdet_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论