Unverified 提交 f489cf4b authored 作者: Tanish's avatar Tanish 提交者: GitHub

Added rewrite for matrix inv(inv(x)) -> x (#893)

上级 ad27dc75
......@@ -569,3 +569,45 @@ def svd_uv_merge(fgraph, node):
or len(fgraph.clients[cl.outputs[2]]) > 0
):
return [cl.outputs[1]]
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_inv(fgraph, node):
"""
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten.
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
valid_inverses = (MatrixInverse, MatrixPinv)
# Check if its a valid inverse operation (either inv/pinv)
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
# If the outer operation is not a valid inverse, we do not apply this rewrite
if not isinstance(node.op.core_op, valid_inverses):
return None
potential_inner_inv = node.inputs[0].owner
if potential_inner_inv is None or potential_inner_inv.op is None:
return None
# Check if inner op is blockwise and and possible inv
if not (
potential_inner_inv
and isinstance(potential_inner_inv.op, Blockwise)
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
):
return None
return [potential_inner_inv.inputs[0]]
......@@ -10,6 +10,7 @@ from pytensor import function
from pytensor import tensor as pt
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
......@@ -554,3 +555,16 @@ def test_svd_uv_merge():
assert node.op.compute_uv
svd_counter += 1
assert svd_counter == 1
@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
def test_inv_inv_rewrite(inv_op_1, inv_op_2):
def get_pt_function(x, op_name):
return getattr(pt.linalg, op_name)(x)
x = pt.matrix("x")
op1 = get_pt_function(x, inv_op_1)
op2 = get_pt_function(op1, inv_op_2)
rewritten_out = rewrite_graph(op2)
assert rewritten_out == x
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论