Unverified 提交 bad8d201 authored 作者: Tanish's avatar Tanish 提交者: GitHub

Slogdet returns naive expression and is optimized later (#1041)

上级 33a4d488
......@@ -11,6 +11,7 @@ from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
......@@ -266,7 +267,33 @@ class SLogDet(Op):
return "SLogDet"
slogdet = Blockwise(SLogDet())
def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
    """
    Return the sign and the natural log of the absolute determinant of ``x``.

    The graph built here is deliberately naive: both outputs are plain
    expressions of ``det(x)``. Later rewrites that work on the det operation
    are expected to optimize it.

    Parameters
    ----------
    x : (..., M, M) tensor or tensor_like
        Input tensor, has to be square.

    Returns
    -------
    A tuple with the following attributes:

    sign : (...) tensor_like
        A number representing the sign of the determinant. For a real matrix,
        this is 1, 0, or -1.
    logabsdet : (...) tensor_like
        The natural log of the absolute value of the determinant.

    If the determinant is zero, then `sign` will be 0 and `logabsdet`
    will be -inf. In all cases, the determinant is equal to
    ``sign * exp(logabsdet)``.
    """
    determinant = det(x)
    sign = ptm.sign(determinant)
    logabsdet = ptm.log(ptm.abs(determinant))
    return sign, logabsdet
class Eig(Op):
......
......@@ -2,6 +2,8 @@ import logging
from collections.abc import Callable
from typing import cast
import numpy as np
from pytensor import Variable
from pytensor import tensor as pt
from pytensor.compile import optdb
......@@ -11,7 +13,7 @@ from pytensor.graph.rewriting.basic import (
in2out,
node_rewriter,
)
from pytensor.scalar.basic import Mul
from pytensor.scalar.basic import Abs, Log, Mul, Sign
from pytensor.tensor.basic import (
AllocDiag,
ExtractDiag,
......@@ -30,11 +32,11 @@ from pytensor.tensor.nlinalg import (
KroneckerProduct,
MatrixInverse,
MatrixPinv,
SLogDet,
det,
inv,
kron,
pinv,
slogdet,
svd,
)
from pytensor.tensor.rewriting.basic import (
......@@ -785,45 +787,6 @@ def rewrite_det_blockdiag(fgraph, node):
return [prod(det_sub_matrices)]
@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_blockdiag(fgraph, node):
    """
    Rewrite the slogdet of a block diagonal matrix in terms of its blocks.

    slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # The rewrite only applies when the input is produced by a blockwise
    # BlockDiagonal op.
    producer = node.inputs[0].owner
    if not producer:
        return None
    if not isinstance(producer.op, Blockwise):
        return None
    if not isinstance(producer.op.core_op, BlockDiagonal):
        return None

    # Each input of the block_diag node is one of the diagonal blocks.
    block_signs, block_logdets = zip(*(slogdet(block) for block in producer.inputs))

    return [prod(block_signs), sum(block_logdets)]
@register_canonicalize
@register_stabilize
@node_rewriter([ExtractDiag])
......@@ -860,10 +823,10 @@ def rewrite_diag_kronecker(fgraph, node):
@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_kronecker(fgraph, node):
@node_rewriter([det])
def rewrite_det_kronecker(fgraph, node):
"""
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those
Parameters
----------
......@@ -884,13 +847,12 @@ def rewrite_slogdet_kronecker(fgraph, node):
# Find the matrices
a, b = potential_kron.inputs
signs, logdets = zip(*[slogdet(a), slogdet(b)])
dets = [det(a), det(b)]
sizes = [a.shape[-1], b.shape[-1]]
prod_sizes = prod(sizes, no_zeros_in_input=True)
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)])
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
return [det_final]
@register_canonicalize
......@@ -989,3 +951,65 @@ optdb.register(
"jax",
position=0.9, # Run before canonicalization
)
@register_specialize
@node_rewriter([det])
def slogdet_specialization(fgraph, node):
    """
    This rewrite targets specific operations related to slogdet, i.e. sign(det), log(det) and log(abs(det)),
    and rewrites them using the SLogDet operation.

    The rewrite is only applied when every use of the determinant can be expressed through SLogDet. If the
    determinant (or its absolute value) is also consumed by any other operation, nothing is rewritten, because
    the Det node would have to stay in the graph and the determinant would end up being computed twice.

    log(det) is replaced by the log-abs-determinant where the sign is not -1, and by NaN where the sign is -1.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    dictionary of Variables, optional
        Dictionary of nodes and what they should be replaced with, or None if no optimization was performed
    """

    def is_elemwise_of(apply_node, scalar_op_type):
        # True when `apply_node` is an Elemwise wrapping the given scalar op type.
        return isinstance(apply_node.op, Elemwise) and isinstance(
            apply_node.op.scalar_op, scalar_op_type
        )

    # Maps each variable to replace onto the name of the SLogDet-derived quantity replacing it.
    dummy_replacements = {}
    for client, _ in fgraph.clients[node.outputs[0]]:
        # Check for sign(det)
        if is_elemwise_of(client, Sign):
            dummy_replacements[client.outputs[0]] = "sign"

        # Check for log(abs(det))
        elif is_elemwise_of(client, Abs):
            abs_clients = [
                abs_client for abs_client, _ in fgraph.clients[client.outputs[0]]
            ]
            # abs(det) must feed only log operations. Any other consumer keeps abs(det), and
            # therefore Det, alive in the graph, so rewriting would compute the determinant twice.
            if not abs_clients or not all(
                is_elemwise_of(abs_client, Log) for abs_client in abs_clients
            ):
                return None
            for log_client in abs_clients:
                dummy_replacements[log_client.outputs[0]] = "log_abs_det"

        # Check for log(det)
        elif is_elemwise_of(client, Log):
            dummy_replacements[client.outputs[0]] = "log_det"

        # Det is used directly for something else, don't rewrite to avoid computing two dets
        else:
            return None

    if not dummy_replacements:
        return None

    [x] = node.inputs
    sign_det_x, log_abs_det_x = SLogDet()(x)
    # log of a negative determinant is undefined for real inputs, so emit NaN there.
    log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x)
    slogdet_specialization_map = {
        "sign": sign_det_x,
        "log_abs_det": log_abs_det_x,
        "log_det": log_det_x,
    }
    return {
        old_var: slogdet_specialization_map[kind]
        for old_var, kind in dummy_replacements.items()
    }
from collections.abc import Sequence
import numpy as np
import pytest
......@@ -22,13 +24,13 @@ def matrix_test():
@pytest.mark.parametrize(
"func",
(pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det),
(pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
)
def test_lin_alg_no_params(func, matrix_test):
x, test_value = matrix_test
out = func(x)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out])
def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3)
......
......@@ -21,6 +21,7 @@ from pytensor.tensor.nlinalg import (
KroneckerProduct,
MatrixInverse,
MatrixPinv,
SLogDet,
matrix_inverse,
svd,
)
......@@ -719,7 +720,7 @@ def test_det_blockdiag_rewrite():
def test_slogdet_blockdiag_rewrite():
n_matrices = 100
n_matrices = 10
matrix_size = (5, 5)
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
......@@ -776,11 +777,34 @@ def test_diag_kronecker_rewrite():
)
def test_det_kronecker_rewrite():
    """Check that det(kron(a, b)) compiles without a KroneckerProduct node and matches NumPy."""
    a, b = pt.dmatrices("a", "b")
    det_output = pt.linalg.det(pt.linalg.kron(a, b))
    f_rewritten = function([a, b], [det_output], mode="FAST_RUN")

    # Rewrite Test
    compiled_ops = [node.op for node in f_rewritten.maker.fgraph.apply_nodes]
    assert not any(isinstance(op, KroneckerProduct) for op in compiled_ops)

    # Value Test
    a_test, b_test = np.random.rand(2, 20, 20)
    expected_det = np.linalg.det(np.kron(a_test, b_test))
    # Same tolerance is used for both the absolute and the relative check.
    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        expected_det,
        f_rewritten(a_test, b_test),
        atol=tol,
        rtol=tol,
    )
def test_slogdet_kronecker_rewrite():
a, b = pt.dmatrices("a", "b")
kron_prod = pt.linalg.kron(a, b)
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
f_rewritten = function([a, b], [sign_output, logdet_output], mode="FAST_RUN")
# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
......@@ -790,7 +814,7 @@ def test_slogdet_kronecker_rewrite():
a_test, b_test = np.random.rand(2, 20, 20)
kron_prod_test = np.kron(a_test, b_test)
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
rewritten_sign_val, rewritten_logdet_val = f_rewritten(a_test, b_test)
assert_allclose(
sign_output_test,
rewritten_sign_val,
......@@ -906,3 +930,69 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes)
def test_slogdet_specialization():
    """Check when sign(det), log(abs(det)) and log(det) are rewritten to SLogDet, and when they are not."""
    x, a = pt.dmatrix("x"), np.random.rand(20, 20)
    det_x, det_a = pt.linalg.det(x), np.linalg.det(a)
    log_abs_det_x, log_abs_det_a = pt.log(pt.abs(det_x)), np.log(np.abs(det_a))
    log_det_x, log_det_a = pt.log(det_x), np.log(det_a)
    sign_det_x, sign_det_a = pt.sign(det_x), np.sign(det_a)
    exp_det_x = pt.exp(det_x)

    tol = 1e-3 if config.floatX == "float32" else 1e-8

    def compile_fast_run(outputs):
        # Compile the outputs and return the function with its rewritten apply nodes.
        fn = function([x], outputs, mode="FAST_RUN")
        return fn, fn.maker.fgraph.apply_nodes

    def assert_specialized(nodes):
        # Exactly one SLogDet node must be present, and no Det node may remain.
        assert sum(isinstance(node.op, SLogDet) for node in nodes) == 1
        assert not any(isinstance(node.op, Det) for node in nodes)

    # REWRITE TESTS
    # sign(det(x)), log(abs(det(x))) and log(det(x)), each compiled on its own
    single_output_cases = (
        (sign_det_x, sign_det_a),
        (log_abs_det_x, log_abs_det_a),
        (log_det_x, log_det_a),
    )
    for symbolic_out, expected_val in single_output_cases:
        f, nodes = compile_fast_run([symbolic_out])
        assert_specialized(nodes)
        assert_allclose(expected_val, f(a), atol=tol, rtol=tol)

    # More than 1 valid function
    _, nodes = compile_fast_run([sign_det_x, log_abs_det_x])
    assert_specialized(nodes)

    # Other functions (rewrite shouldnt be applied to these)
    # Only invalid functions, then Invalid + Valid function
    for outputs in ([exp_det_x], [exp_det_x, sign_det_x]):
        _, nodes = compile_fast_run(outputs)
        assert not any(isinstance(node.op, SLogDet) for node in nodes)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论