Unverified 提交 1a1c62bb authored 作者: Tanish's avatar Tanish 提交者: GitHub

added rewrites for inv(diag(x)) and inv(eye) (#898)

* updated tests * updated rewrites * paramterized tests and added batch case * minor changes
上级 7eca2527
......@@ -3,6 +3,7 @@ from collections.abc import Callable
from typing import cast
from pytensor import Variable
from pytensor import tensor as pt
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
......@@ -48,6 +49,7 @@ from pytensor.tensor.slinalg import (
logger = logging.getLogger(__name__)
ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv)
def is_matrix_transpose(x: TensorVariable) -> bool:
......@@ -592,11 +594,10 @@ def rewrite_inv_inv(fgraph, node):
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
valid_inverses = (MatrixInverse, MatrixPinv)
# Check if its a valid inverse operation (either inv/pinv)
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
# If the outer operation is not a valid inverse, we do not apply this rewrite
if not isinstance(node.op.core_op, valid_inverses):
if not isinstance(node.op.core_op, ALL_INVERSE_OPS):
return None
potential_inner_inv = node.inputs[0].owner
......@@ -607,7 +608,96 @@ def rewrite_inv_inv(fgraph, node):
if not (
potential_inner_inv
and isinstance(potential_inner_inv.op, Blockwise)
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS)
):
return None
return [potential_inner_inv.inputs[0]]
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_eye_to_eye(fgraph, node):
"""
This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself
The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside an inverse op.
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
core_op = node.op.core_op
if not (isinstance(core_op, ALL_INVERSE_OPS)):
return None
# Check whether input to inverse is Eye and the 1's are on main diagonal
potential_eye = node.inputs[0]
if not (
potential_eye.owner
and isinstance(potential_eye.owner.op, Eye)
and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0
):
return None
return [potential_eye]
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
"""
This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements.
This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
core_op = node.op.core_op
if not (isinstance(core_op, ALL_INVERSE_OPS)):
return None
inputs = node.inputs[0]
# Check for use of pt.diag first
if (
inputs.owner
and isinstance(inputs.owner.op, AllocDiag)
and AllocDiag.is_offset_zero(inputs.owner)
):
inv_input = inputs.owner.inputs[0]
inv_val = pt.diag(1 / inv_input)
return [inv_val]
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
inputs_or_none = _find_diag_from_eye_mul(inputs)
if inputs_or_none is None:
return None
eye_input, non_eye_inputs = inputs_or_none
# Dealing with only one other input
if len(non_eye_inputs) != 1:
return None
non_eye_input = non_eye_inputs[0]
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
if non_eye_input.type.broadcastable[-2:] == (False, False):
non_eye_diag = non_eye_input.diagonal(axis1=-1, axis2=-2)
non_eye_input = pt.shape_padaxis(non_eye_diag, -2)
return [eye_input / non_eye_input]
......@@ -41,6 +41,9 @@ from tests import unittest_tools as utt
from tests.test_rop import break_op
ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8
def test_rop_lop():
mx = matrix("mx")
mv = matrix("mv")
......@@ -557,14 +560,105 @@ def test_svd_uv_merge():
assert svd_counter == 1
def get_pt_function(x, op_name):
return getattr(pt.linalg, op_name)(x)
@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
def test_inv_inv_rewrite(inv_op_1, inv_op_2):
def get_pt_function(x, op_name):
return getattr(pt.linalg, op_name)(x)
x = pt.matrix("x")
op1 = get_pt_function(x, inv_op_1)
op2 = get_pt_function(op1, inv_op_2)
rewritten_out = rewrite_graph(op2)
assert rewritten_out == x
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_eye_to_eye(inv_op):
x = pt.eye(10)
x_inv = get_pt_function(x, inv_op)
f_rewritten = function([], x_inv, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
# Rewrite Test
valid_inverses = (MatrixInverse, MatrixPinv)
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
# Value Test
x_test = np.eye(10)
x_inv_val = np.linalg.inv(x_test)
rewritten_val = f_rewritten()
assert_allclose(
x_inv_val,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
@pytest.mark.parametrize(
"shape",
[(), (7,), (7, 7), (5, 7, 7)],
ids=["scalar", "vector", "matrix", "batched"],
)
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_diag_from_eye_mul(shape, inv_op):
# Initializing x based on scalar/vector/matrix
x = pt.tensor("x", shape=shape)
x_diag = pt.eye(7) * x
# Calculating inverse using pt.linalg.inv
x_inv = get_pt_function(x_diag, inv_op)
# REWRITE TEST
f_rewritten = function([x], x_inv, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
valid_inverses = (MatrixInverse, MatrixPinv)
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
# NUMERIC VALUE TEST
if len(shape) == 0:
x_test = np.array(np.random.rand()).astype(config.floatX)
elif len(shape) == 1:
x_test = np.random.rand(*shape).astype(config.floatX)
else:
x_test = np.random.rand(*shape).astype(config.floatX)
x_test_matrix = np.eye(7) * x_test
inverse_matrix = np.linalg.inv(x_test_matrix)
rewritten_inverse = f_rewritten(x_test)
assert_allclose(
inverse_matrix,
rewritten_inverse,
atol=ATOL,
rtol=RTOL,
)
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_diag_from_diag(inv_op):
x = pt.dvector("x")
x_diag = pt.diag(x)
x_inv = get_pt_function(x_diag, inv_op)
# REWRITE TEST
f_rewritten = function([x], x_inv, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
valid_inverses = (MatrixInverse, MatrixPinv)
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
# NUMERIC VALUE TEST
x_test = np.random.rand(10)
x_test_matrix = np.eye(10) * x_test
inverse_matrix = np.linalg.inv(x_test_matrix)
rewritten_inverse = f_rewritten(x_test)
assert_allclose(
inverse_matrix,
rewritten_inverse,
atol=ATOL,
rtol=RTOL,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论