提交 43d8e303 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rewrite ExtractDiagonal of AllocDiagonal

上级 f7c4e163
...@@ -19,6 +19,7 @@ from pytensor.scalar import Add, ScalarConstant, ScalarType ...@@ -19,6 +19,7 @@ from pytensor.scalar import Add, ScalarConstant, ScalarType
from pytensor.scalar import constant as scalar_constant from pytensor.scalar import constant as scalar_constant
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
ExtractDiag,
Join, Join,
ScalarFromTensor, ScalarFromTensor,
TensorFromScalar, TensorFromScalar,
...@@ -26,6 +27,7 @@ from pytensor.tensor.basic import ( ...@@ -26,6 +27,7 @@ from pytensor.tensor.basic import (
cast, cast,
concatenate, concatenate,
expand_dims, expand_dims,
full,
get_scalar_constant_value, get_scalar_constant_value,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
register_infer_shape, register_infer_shape,
...@@ -1793,3 +1795,96 @@ optdb["specialize"].register( ...@@ -1793,3 +1795,96 @@ optdb["specialize"].register(
"numba", "numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested use_db_name_as_tag=False, # Not included if only "specialize" is requested
) )
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_diagonal_set_subtensor(fgraph, node):
"""Undo extract diagonal from a set diagonal
This rewrites the following pattern:
y = write_diagonal*(x, x_diag, offset=k1)
z = extract_diag(y, offset=k2)
as:
z = diag_x, if k1 == k2
z = x if k1 != k2
* write_diagonal is not an atomic operation, but a sequence of Arange/SetSubtensor operations.
"""
def is_cosntant_arange(var) -> bool:
if not (isinstance(var, TensorConstant) and var.type.ndim == 1):
return False
data = var.data
start, stop = data[0], data[-1] + 1
return data.size == (stop - start) and (data == np.arange(start, stop)).all() # type: ignore
[diag_x] = node.inputs
if not (
diag_x.owner is not None
and isinstance(diag_x.owner.op, AdvancedIncSubtensor)
and diag_x.owner.op.set_instead_of_inc
):
return None
x, y, *idxs = diag_x.owner.inputs
if not (
x.type.ndim >= 2
and None not in x.type.shape[-2:]
and x.type.shape[-2] == x.type.shape[-1]
):
# TODO: for now we only support rewrite with static square shape for x
return None
op = node.op
if op.axis2 > len(idxs):
return None
# Check all non-axis indices are full slices
axis = {op.axis1, op.axis2}
if not all(is_full_slice(idx) for i, idx in enumerate(idxs) if i not in axis):
return None
# Check axis indices are arange we would expect from setting on the diagonal
axis1_idx = idxs[op.axis1]
axis2_idx = idxs[op.axis2]
if not (is_cosntant_arange(axis1_idx) and is_cosntant_arange(axis2_idx)):
return None
dim_length = x.type.shape[-1]
offset = op.offset
start_stop1 = (axis1_idx.data[0], axis1_idx.data[-1] + 1)
start_stop2 = (axis2_idx.data[0], axis2_idx.data[-1] + 1)
orig_start1, orig_start2 = start_stop1[0], start_stop2[0]
if offset < 0:
# The logic for checking if we are selecting or not a diagonal for negative offset is the same
# as the one with positive offset but swapped axis
start_stop1, start_stop2 = start_stop2, start_stop1
offset = -offset
start1, stop1 = start_stop1
start2, stop2 = start_stop2
if (
start1 == 0
and start2 == offset
and stop1 == dim_length - offset
and stop2 == dim_length
):
# We are extracting the just written diagonal
if y.type.ndim == 0 or y.type.shape[-1] == 1:
# We may need to broadcast y
y = full((*x.shape[:-2], dim_length - offset), y, dtype=x.type.dtype)
return [y]
elif (orig_start2 - orig_start1) != op.offset:
# Some other diagonal was written, ignore it
return [op(x)]
else:
# A portion, but no the whole diagonal was written, don't do anything
return None
import random
import numpy as np import numpy as np
import pytest import pytest
...@@ -9,7 +11,7 @@ from pytensor.compile.function import function ...@@ -9,7 +11,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import vectorize_graph from pytensor.graph import rewrite_graph, vectorize_graph
from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
from pytensor.graph.rewriting.basic import check_stack_trace from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
...@@ -1956,3 +1958,37 @@ class TestUselessSlice: ...@@ -1956,3 +1958,37 @@ class TestUselessSlice:
f(test_x, -2), f(test_x, -2),
test_x[0:3:-2, -1:-6:2, ::], test_x[0:3:-2, -1:-6:2, ::],
) )
def test_extract_diag_of_diagonal_set_subtensor():
A = pt.full((2, 6, 6), np.nan)
rows = pt.arange(A.shape[-2])
cols = pt.arange(A.shape[-1])
write_offsets = [-2, -1, 0, 1, 2]
# Randomize order of write operations, to make sure rewrite is not sensitive to it
random.shuffle(write_offsets)
for offset in write_offsets:
value = offset + 0.1 * offset
if offset == 0:
A = A[..., rows, cols].set(value)
elif offset > 0:
A = A[..., rows[:-offset], cols[offset:]].set(value)
else:
offset = -offset
A = A[..., rows[offset:], cols[:-offset]].set(value)
# Add a partial diagonal along offset 3
A = A[..., rows[1:-3], cols[4:]].set(np.pi)
read_offsets = [-2, -1, 0, 1, 2, 3]
outs = [A.diagonal(offset=offset, axis1=-2, axis2=-1) for offset in read_offsets]
rewritten_outs = rewrite_graph(outs, include=("ShapeOpt", "canonicalize"))
# Every output should just be an Alloc with value
expected_outs = []
for offset in read_offsets[:-1]:
value = np.asarray(offset + 0.1 * offset, dtype=A.type.dtype)
expected_outs.append(pt.full((np.int64(2), np.int8(6 - abs(offset))), value))
# The partial diagonal shouldn't be rewritten
expected_outs.append(outs[-1])
assert equal_computations(rewritten_outs, expected_outs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论