提交 2c41735a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Vectorize ExtractDiag

Also adds better static shapes
上级 89c7544a
...@@ -26,6 +26,7 @@ from pytensor.graph import RewriteDatabaseQuery ...@@ -26,6 +26,7 @@ from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape, Type from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
...@@ -3497,10 +3498,17 @@ class ExtractDiag(Op): ...@@ -3497,10 +3498,17 @@ class ExtractDiag(Op):
if x.ndim < 2: if x.ndim < 2:
raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x) raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x)
out_shape = [
st_dim
for i, st_dim in enumerate(x.type.shape)
if i not in (self.axis1, self.axis2)
] + [None]
return Apply( return Apply(
self, self,
[x], [x],
[x.type.clone(dtype=x.dtype, shape=(None,) * (x.ndim - 1))()], [x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -3601,6 +3609,17 @@ def diagonal(a, offset=0, axis1=0, axis2=1): ...@@ -3601,6 +3609,17 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
return ExtractDiag(offset, axis1, axis2)(a) return ExtractDiag(offset, axis1, axis2)(a)
@_vectorize_node.register(ExtractDiag)
def vectorize_extract_diag(op: ExtractDiag, node, batched_x):
batched_ndims = batched_x.type.ndim - node.inputs[0].type.ndim
return diagonal(
batched_x,
offset=op.offset,
axis1=op.axis1 + batched_ndims,
axis2=op.axis2 + batched_ndims,
).owner
def trace(a, offset=0, axis1=0, axis2=1): def trace(a, offset=0, axis1=0, axis2=1):
""" """
Returns the sum along diagonals of the array. Returns the sum along diagonals of the array.
......
...@@ -20,7 +20,7 @@ from pytensor.graph.replace import clone_replace ...@@ -20,7 +20,7 @@ from pytensor.graph.replace import clone_replace
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar import autocast_float, autocast_float_as from pytensor.scalar import autocast_float, autocast_float_as
from pytensor.tensor import NoneConst from pytensor.tensor import NoneConst, vectorize
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocEmpty, AllocEmpty,
...@@ -88,6 +88,7 @@ from pytensor.tensor.basic import ( ...@@ -88,6 +88,7 @@ from pytensor.tensor.basic import (
vertical_stack, vertical_stack,
zeros_like, zeros_like,
) )
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import dense_dot from pytensor.tensor.math import dense_dot
...@@ -4517,3 +4518,26 @@ def test_trace(): ...@@ -4517,3 +4518,26 @@ def test_trace():
trace(x, offset=-1, axis1=0, axis2=-1).eval(), trace(x, offset=-1, axis1=0, axis2=-1).eval(),
np.trace(x_val, offset=-1, axis1=0, axis2=-1), np.trace(x_val, offset=-1, axis1=0, axis2=-1),
) )
def test_vectorize_extract_diag():
signature = "(a1,b,a2)->(b,a)"
def core_pt(x):
return at.diagonal(x, offset=1, axis1=0, axis2=2)
def core_np(x):
return np.diagonal(x, offset=1, axis1=0, axis2=2)
x = tensor(shape=(5, 5, 5, 5))
vectorize_pt = function([x], vectorize(core_pt, signature=signature)(x))
assert not any(
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
)
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
vectorize_np = np.vectorize(core_np, signature=signature)
np.testing.assert_allclose(
vectorize_pt(x_test),
vectorize_np(x_test),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论