Unverified 提交 be6a0322 authored 作者: Tanish's avatar Tanish 提交者: GitHub

Adds functions to rewrite cholesky decomposition of identity and diagonal matrices (#925)

* fixed merge conflicts
* fixed failing tests and added rewrite for pt.diag
* minor changes; added test to not apply rewrite
* added test for batched case and more cases of not applying rewrite
* minor changes
上级 3e98b9f7
...@@ -887,3 +887,82 @@ def rewrite_slogdet_kronecker(fgraph, node): ...@@ -887,3 +887,82 @@ def rewrite_slogdet_kronecker(fgraph, node):
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_remove_useless_cholesky(fgraph, node):
    """
    Replace the Cholesky decomposition of an identity matrix with the identity itself.

    The Cholesky factor of an identity matrix is that same matrix, so the
    decomposition is redundant. An identity input is recognised as the output
    of an `Eye` Op whose last input (the offset ``k``) exposes a concrete value
    through a ``data`` attribute, and that value is 0.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Only Blockwise nodes wrapping a Cholesky core op are of interest.
    if not isinstance(node.op.core_op, Cholesky):
        return None

    candidate = node.inputs[0]
    producer = candidate.owner

    # The matrix being decomposed must come straight out of an Eye Op.
    if not producer or not isinstance(producer.op, Eye):
        return None

    # The offset must carry a known value, otherwise we cannot prove that the
    # ones sit on the main diagonal.
    offset = producer.inputs[-1]
    if not hasattr(offset, "data"):
        return None
    if offset.data.item() != 0:
        return None

    return [candidate]
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
    """
    Replace the Cholesky decomposition of a diagonal matrix with an elementwise square root.

    The Cholesky factor of a diagonal matrix is the diagonal matrix holding the
    square roots of the original diagonal entries. Two ways of building a
    diagonal matrix are recognised:

    1. An `AllocDiag` Op with zero offset applied to a vector (e.g. ``pt.diag(v)``).
    2. An elementwise multiplication of an identity matrix with exactly one
       other input, as detected by `_find_diag_from_eye_mul`.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Find whether cholesky op is being applied
    if not isinstance(node.op.core_op, Cholesky):
        return None

    # `inp` rather than `input`, so the builtin is not shadowed.
    [inp] = node.inputs

    # Check for use of pt.diag first
    if (
        inp.owner
        and isinstance(inp.owner.op, AllocDiag)
        and AllocDiag.is_offset_zero(inp.owner)
    ):
        diag_input = inp.owner.inputs[0]
        # pt.diag only *builds* a diagonal matrix from a 1-D input; given a
        # higher-rank input it would extract a diagonal (or fail) and the
        # replacement would not match the Cholesky output. Skip the rewrite
        # in that case rather than produce an invalid graph.
        if diag_input.type.ndim != 1:
            return None
        return [pt.diag(diag_input**0.5)]

    # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
    inputs_or_none = _find_diag_from_eye_mul(inp)
    if inputs_or_none is None:
        return None

    eye_input, non_eye_inputs = inputs_or_none

    # Dealing with only one other input
    if len(non_eye_inputs) != 1:
        return None

    [non_eye_input] = non_eye_inputs

    # Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
    # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
    if non_eye_input.type.broadcastable[-2:] == (False, False):
        non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2)
        if eye_input.type.ndim > 2:
            # Re-insert an axis so the extracted diagonals broadcast against
            # the (batched) identity along its rows.
            non_eye_input = pt.shape_padaxis(non_eye_input, -2)

    return [eye_input * (non_eye_input**0.5)]
...@@ -803,3 +803,106 @@ def test_slogdet_kronecker_rewrite(): ...@@ -803,3 +803,106 @@ def test_slogdet_kronecker_rewrite():
atol=1e-3 if config.floatX == "float32" else 1e-8, atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8,
) )
def test_cholesky_eye_rewrite():
    """Cholesky of ``pt.eye(10)`` compiles without a Cholesky node and matches NumPy."""
    identity = pt.eye(10)
    compiled = function([], pt.linalg.cholesky(identity), mode="FAST_RUN")

    # Rewrite Test
    remaining_ops = [apply.op for apply in compiled.maker.fgraph.apply_nodes]
    assert not any(isinstance(op, Cholesky) for op in remaining_ops)

    # Value Test
    tol = 1e-3 if config.floatX == "float32" else 1e-8
    expected = np.linalg.cholesky(np.eye(10))
    assert_allclose(expected, compiled(), atol=tol, rtol=tol)
@pytest.mark.parametrize(
    "shape",
    [(), (7,), (7, 7), (5, 7, 7)],
    ids=["scalar", "vector", "matrix", "batched"],
)
def test_cholesky_diag_from_eye_mul(shape):
    """Cholesky of ``pt.eye(7) * x`` is rewritten away and matches NumPy for every ``shape``."""
    # Initializing x based on scalar/vector/matrix
    x = pt.tensor("x", shape=shape)
    y = pt.eye(7) * x
    # Performing cholesky decomposition using pt.linalg.cholesky
    z_cholesky = pt.linalg.cholesky(y)

    # REWRITE TEST
    f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes
    assert not any(isinstance(node.op, Cholesky) for node in nodes)

    # NUMERIC VALUE TEST
    # One seeded draw covers every shape (including the 0-d scalar case), so
    # the test is reproducible. Values are kept strictly positive so that the
    # reference matrix is always positive definite.
    rng = np.random.default_rng(925)
    x_test = np.asarray(rng.uniform(0.1, 1.0, size=shape)).astype(config.floatX)
    x_test_matrix = np.eye(7) * x_test
    cholesky_val = np.linalg.cholesky(x_test_matrix)
    rewritten_val = f_rewritten(x_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        cholesky_val,
        rewritten_val,
        atol=tol,
        rtol=tol,
    )
def test_cholesky_diag_from_diag():
    """Cholesky of ``pt.diag(x)`` is rewritten away and matches the NumPy result."""
    vec = pt.dvector("x")
    chol_of_diag = pt.linalg.cholesky(pt.diag(vec))

    # REWRITE TEST
    compiled = function([vec], chol_of_diag, mode="FAST_RUN")
    remaining_ops = [apply.op for apply in compiled.maker.fgraph.apply_nodes]
    assert not any(isinstance(op, Cholesky) for op in remaining_ops)

    # NUMERIC VALUE TEST
    diag_values = np.random.rand(10)
    expected = np.linalg.cholesky(np.eye(10) * diag_values)
    actual = compiled(diag_values)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(expected, actual, atol=tol, rtol=tol)
def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
    """The sqrt-diag rewrite must leave Cholesky in place for these two inputs."""

    def cholesky_survives(sym_input, matrix):
        # Compile the decomposition and report whether a Cholesky op remains.
        compiled = function(
            [sym_input], pt.linalg.cholesky(matrix), mode="FAST_RUN"
        )
        return any(
            isinstance(apply.op, Cholesky)
            for apply in compiled.maker.fgraph.apply_nodes
        )

    # Case 1 : y is not a diagonal matrix because of k = -1
    mat = pt.tensor("x", shape=(7, 7))
    # REWRITE TEST (should not be applied)
    assert cholesky_survives(mat, pt.eye(7, k=-1) * mat)

    # Case 2 : eye is degenerate
    scal = pt.scalar("x")
    assert cholesky_survives(scal, pt.eye(1) * scal)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论