Unverified 提交 94e9ef06 authored 作者: Tanish's avatar Tanish 提交者: GitHub

Rewrite determinant of diagonal matrix as product of diagonal (#797)

* Added det-diag rewrite * fixed pt.diagonal error * Added test for rewrite * Added test for rewrite * fixed test * added check for verifying rewrite * fixed other failing test * added docstring * updated docstring * fixed mypy error * added det_diag_from_diag and test * fixed node rewriter name * added row/col tests * updated check for eye * updated rewrite and tests * added check for eye_input and new test for cases where not to apply rewrite * does not apply rewrite to specific cases * typecasted test variable * typecast variables * removed shape known check; fails for rectangle eye * added new tests for (1,1) eye and rectangle eye * added helper function for diag from eye_mul * updated case for no rewrite which was failing tests * cleaned code; updated rectangle_eye test which is an invalid rewrite * add check for k in pt.eye * Update pytensor/tensor/rewriting/linalg.py Co-authored-by: 's avatarRicardo Vieira <28983449+ricardoV94@users.noreply.github.com> * typecasted det_val * fixed final typecasting * fixed merge * fixed failing rectangle eye test * fixed typo --------- Co-authored-by: 's avatarRicardo Vieira <28983449+ricardoV94@users.noreply.github.com>
上级 bf8a1b5a
...@@ -5,13 +5,15 @@ from typing import cast ...@@ -5,13 +5,15 @@ from typing import cast
from pytensor import Variable from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
PatternNodeRewriter,
copy_stack_trace, copy_stack_trace,
node_rewriter, node_rewriter,
) )
from pytensor.tensor.basic import TensorVariable, diagonal from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import ( from pytensor.tensor.nlinalg import (
SVD, SVD,
...@@ -39,6 +41,7 @@ from pytensor.tensor.slinalg import ( ...@@ -39,6 +41,7 @@ from pytensor.tensor.slinalg import (
solve, solve,
solve_triangular, solve_triangular,
) )
from pytensor.tensor.subtensor import advanced_set_subtensor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -384,6 +387,104 @@ def local_lift_through_linalg( ...@@ -384,6 +387,104 @@ def local_lift_through_linalg(
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def _find_diag_from_eye_mul(potential_mul_input):
# Check if the op is Elemwise and mul
if not (
potential_mul_input.owner is not None
and isinstance(potential_mul_input.owner.op, Elemwise)
and isinstance(potential_mul_input.owner.op.scalar_op, Mul)
):
return None
# Find whether any of the inputs to mul is Eye
inputs_to_mul = potential_mul_input.owner.inputs
eye_input = [
mul_input
for mul_input in inputs_to_mul
if mul_input.owner and isinstance(mul_input.owner.op, Eye)
]
# Check if 1's are being put on the main diagonal only (k = 0)
if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0:
return None
# If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite
if eye_input and eye_input[0].broadcastable[-2:] != (False, False):
return None
# Get all non Eye inputs (scalars/matrices/vectors)
non_eye_inputs = list(set(inputs_to_mul) - set(eye_input))
return eye_input, non_eye_inputs
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([det])
def rewrite_det_diag_from_eye_mul(fgraph, node):
"""
This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements.
The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix.
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
potential_mul_input = node.inputs[0]
eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input)
if eye_non_eye_inputs is None:
return None
eye_input, non_eye_inputs = eye_non_eye_inputs
# Dealing with only one other input
if len(non_eye_inputs) != 1:
return None
useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
# Checking if original x was scalar/vector/matrix
if useful_non_eye.type.broadcastable[-2:] == (True, True):
# For scalar
det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0])
elif useful_non_eye.type.broadcastable[-2:] == (False, False):
# For Matrix
det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
else:
# For vector
det_val = useful_non_eye.prod(axis=(-1, -2))
det_val = det_val.astype(node.outputs[0].type.dtype)
return [det_val]
arange = ARange("int64")
det_diag_from_diag = PatternNodeRewriter(
(
det,
(
advanced_set_subtensor,
(alloc, 0, "sh1", "sh2"),
"x",
(arange, 0, "stop", 1),
(arange, 0, "stop", 1),
),
),
(prod, "x"),
name="det_diag_from_diag",
allow_multiple_clients=True,
)
register_canonicalize(det_diag_from_diag)
register_stabilize(det_diag_from_diag)
register_specialize(det_diag_from_diag)
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
......
...@@ -394,6 +394,95 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): ...@@ -394,6 +394,95 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8) np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
@pytest.mark.parametrize(
"shape",
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
)
def test_det_diag_from_eye_mul(shape):
# Initializing x based on scalar/vector/matrix
x = pt.tensor("x", shape=shape)
y = pt.eye(7) * x
# Calculating determinant value using pt.linalg.det
z_det = pt.linalg.det(y)
# REWRITE TEST
f_rewritten = function([x], z_det, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Det) for node in nodes)
# NUMERIC VALUE TEST
if len(shape) == 0:
x_test = np.array(np.random.rand()).astype(config.floatX)
elif len(shape) == 1:
x_test = np.random.rand(*shape).astype(config.floatX)
else:
x_test = np.random.rand(*shape).astype(config.floatX)
x_test_matrix = np.eye(7) * x_test
det_val = np.linalg.det(x_test_matrix)
rewritten_val = f_rewritten(x_test)
assert_allclose(
det_val,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
def test_det_diag_from_diag():
x = pt.tensor("x", shape=(None,))
x_diag = pt.diag(x)
y = pt.linalg.det(x_diag)
# REWRITE TEST
f_rewritten = function([x], y, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Det) for node in nodes)
# NUMERIC VALUE TEST
x_test = np.random.rand(7).astype(config.floatX)
x_test_matrix = np.eye(7) * x_test
det_val = np.linalg.det(x_test_matrix)
rewritten_val = f_rewritten(x_test)
assert_allclose(
det_val,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
def test_dont_apply_det_diag_rewrite_for_1_1():
x = pt.matrix("x")
x_diag = pt.eye(1, 1) * x
y = pt.linalg.det(x_diag)
f_rewritten = function([x], y, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Det) for node in nodes)
# Numeric Value test
x_test = np.random.normal(size=(3, 3)).astype(config.floatX)
x_test_matrix = np.eye(1, 1) * x_test
det_val = np.linalg.det(x_test_matrix)
rewritten_val = f_rewritten(x_test)
assert_allclose(
det_val,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
def test_det_diag_incorrect_for_rectangle_eye():
x = pt.matrix("x")
x_diag = pt.eye(7, 5) * x
with pytest.raises(ValueError, match="Determinant not defined"):
pt.linalg.det(x_diag)
def test_svd_uv_merge(): def test_svd_uv_merge():
a = matrix("a") a = matrix("a")
s_1 = svd(a, full_matrices=False, compute_uv=False) s_1 = svd(a, full_matrices=False, compute_uv=False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论