提交 2c41735a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Vectorize ExtractDiag

Also adds better static shapes
上级 89c7544a
......@@ -26,6 +26,7 @@ from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp
......@@ -3497,10 +3498,17 @@ class ExtractDiag(Op):
if x.ndim < 2:
raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x)
out_shape = [
st_dim
for i, st_dim in enumerate(x.type.shape)
if i not in (self.axis1, self.axis2)
] + [None]
return Apply(
self,
[x],
[x.type.clone(dtype=x.dtype, shape=(None,) * (x.ndim - 1))()],
[x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()],
)
def perform(self, node, inputs, outputs):
......@@ -3601,6 +3609,17 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
return ExtractDiag(offset, axis1, axis2)(a)
@_vectorize_node.register(ExtractDiag)
def vectorize_extract_diag(op: ExtractDiag, node, batched_x):
batched_ndims = batched_x.type.ndim - node.inputs[0].type.ndim
return diagonal(
batched_x,
offset=op.offset,
axis1=op.axis1 + batched_ndims,
axis2=op.axis2 + batched_ndims,
).owner
def trace(a, offset=0, axis1=0, axis2=1):
"""
Returns the sum along diagonals of the array.
......
......@@ -20,7 +20,7 @@ from pytensor.graph.replace import clone_replace
from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert
from pytensor.scalar import autocast_float, autocast_float_as
from pytensor.tensor import NoneConst
from pytensor.tensor import NoneConst, vectorize
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
......@@ -88,6 +88,7 @@ from pytensor.tensor.basic import (
vertical_stack,
zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import dense_dot
......@@ -4517,3 +4518,26 @@ def test_trace():
trace(x, offset=-1, axis1=0, axis2=-1).eval(),
np.trace(x_val, offset=-1, axis1=0, axis2=-1),
)
def test_vectorize_extract_diag():
signature = "(a1,b,a2)->(b,a)"
def core_pt(x):
return at.diagonal(x, offset=1, axis1=0, axis2=2)
def core_np(x):
return np.diagonal(x, offset=1, axis1=0, axis2=2)
x = tensor(shape=(5, 5, 5, 5))
vectorize_pt = function([x], vectorize(core_pt, signature=signature)(x))
assert not any(
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
)
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
vectorize_np = np.vectorize(core_np, signature=signature)
np.testing.assert_allclose(
vectorize_pt(x_test),
vectorize_np(x_test),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论