Unverified 提交 3e98b9f7 authored 作者: Tanish's avatar Tanish 提交者: GitHub

Adding rewrites involving kronecker product (#975)

* Added rewrite for diag of Kronecker product
* Added rewrite for slogdet; added docstrings for rewrites
* Fixed typo
上级 56327779
......@@ -22,7 +22,7 @@ from pytensor.tensor.basic import (
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
......@@ -818,3 +818,72 @@ def rewrite_slogdet_blockdiag(fgraph, node):
)
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
@register_canonicalize
@register_stabilize
@node_rewriter([ExtractDiag])
def rewrite_diag_kronecker(fgraph, node):
    """
    Simplify the main diagonal of the Kronecker product of two matrices.

    Instead of materialising ``kron(a, b)`` and reading its diagonal, the
    diagonals of the two factors are extracted and their outer product is
    returned, flattened to a vector:

        diag(kron(a, b)) -> outer(diag(a), diag(b)).flatten()

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # The identity above only describes the main diagonal; an off-diagonal
    # request (offset != 0) must keep the original graph.
    if node.op.offset != 0:
        return None

    # Check for inner kron operation
    potential_kron = node.inputs[0].owner
    if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
        return None

    # Find the matrices
    # NOTE(review): the identity is derived for square factors; shapes are not
    # checked here -- confirm whether non-square inputs can reach this rewrite.
    a, b = potential_kron.inputs
    diag_a, diag_b = diag(a), diag(b)
    outer_prod_as_vector = outer(diag_a, diag_b).flatten()

    return [outer_prod_as_vector]
@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_kronecker(fgraph, node):
    """
    Express the slogdet of a Kronecker product through the slogdets of its factors.

    With ``n = a.shape[-1] * b.shape[-1]``, the replacement outputs are

        sign   = sign_a ** (n / a.shape[-1]) * sign_b ** (n / b.shape[-1])
        logdet = logdet_a * n / a.shape[-1] + logdet_b * n / b.shape[-1]

    so the full Kronecker-structured matrix never has to be decomposed.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Only applies when the slogdet input is produced by a KroneckerProduct op
    kron_node = node.inputs[0].owner
    if not kron_node or not isinstance(kron_node.op, KroneckerProduct):
        return None

    # Factors of the Kronecker product and their individual slogdets
    a, b = kron_node.inputs
    sign_a, logdet_a = slogdet(a)
    sign_b, logdet_b = slogdet(b)

    size_a = a.shape[-1]
    size_b = b.shape[-1]
    total_size = prod([size_a, size_b], no_zeros_in_input=True)

    # Each factor contributes with a multiplicity of total_size / its own size
    scaled_signs = [
        sign_a ** (total_size / size_a),
        sign_b ** (total_size / size_b),
    ]
    scaled_logdets = [
        logdet_a * total_size / size_a,
        logdet_b * total_size / size_b,
    ]

    return [prod(scaled_signs, no_zeros_in_input=True), sum(scaled_logdets)]
......@@ -751,3 +751,55 @@ def test_slogdet_blockdiag_rewrite():
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
def test_diag_kronecker_rewrite():
    """Check that diag(kron(a, b)) compiles without a KroneckerProduct node and matches NumPy."""
    a, b = pt.dmatrices("a", "b")
    diag_of_kron = pt.diag(pt.linalg.kron(a, b))
    f_rewritten = function([a, b], diag_of_kron, mode="FAST_RUN")

    # Rewrite Test: the compiled graph must not contain the Kronecker product op
    compiled_nodes = f_rewritten.maker.fgraph.apply_nodes
    assert not any(
        isinstance(apply_node.op, KroneckerProduct) for apply_node in compiled_nodes
    )

    # Value Test: compare against the reference computed with NumPy
    a_test, b_test = np.random.rand(2, 20, 20)
    expected = np.diag(np.kron(a_test, b_test))
    actual = f_rewritten(a_test, b_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        expected,
        actual,
        atol=tol,
        rtol=tol,
    )
def test_slogdet_kronecker_rewrite():
    """Check that slogdet(kron(a, b)) is rewritten away and still matches NumPy.

    The function is compiled with ``a`` and ``b`` as inputs so that the
    KroneckerProduct node is actually part of the graph being rewritten.
    Compiling on ``kron_prod`` itself would cut that node out of the graph and
    make the rewrite check pass trivially.
    """
    a, b = pt.dmatrices("a", "b")
    kron_prod = pt.linalg.kron(a, b)
    sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
    f_rewritten = function([a, b], [sign_output, logdet_output], mode="FAST_RUN")

    # Rewrite Test: the compiled graph must not contain the Kronecker product op
    nodes = f_rewritten.maker.fgraph.apply_nodes
    assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)

    # Value Test: reference values come from NumPy on the explicit product
    a_test, b_test = np.random.rand(2, 20, 20)
    kron_prod_test = np.kron(a_test, b_test)
    sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
    rewritten_sign_val, rewritten_logdet_val = f_rewritten(a_test, b_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        sign_output_test,
        rewritten_sign_val,
        atol=tol,
        rtol=tol,
    )
    assert_allclose(
        logdet_output_test,
        rewritten_logdet_val,
        atol=tol,
        rtol=tol,
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论