提交 6e60fe28 authored 作者: abalkin's avatar abalkin

WIP: Implemented non-zero offset case in Diagonal.infer_shape().

上级 66ba51fe
......@@ -7216,12 +7216,21 @@ class Diagonal(Op):
return [square_diagonal(gz)]
def infer_shape(self, node, shapes):
xdims = list(shapes[0])
d0 = minimum(xdims[self.axis1], xdims[self.axis2])
xdims = [d for i,d in enumerate(shapes[0])
if i not in (self.axis1, self.axis2)]
xdims.append(d0)
return [tuple(xdims)]
in_shape, = shapes
dim1 = in_shape[self.axis1]
dim2 = in_shape[self.axis2]
out_shape = [d for i,d in enumerate(in_shape)
if i not in (self.axis1, self.axis2)]
# The following logic is inspired by C code of PyArray_Diagonal().
offset = self.offset
if offset > 0:
diag_size = clip(dim2 - offset, 0, dim1)
elif offset < 0:
diag_size = clip(dim1 + offset, 0, dim2)
else:
diag_size = minimum(dim1, dim2)
out_shape.append(diag_size)
return [tuple(out_shape)]
def __str__(self):
return self.__class__.__name__
......
......@@ -6546,6 +6546,15 @@ class TestInferShape(utt.InferShapeTester):
atens3 = tensor3()
atens3_val = rand(4, 5, 3)
atens3_diag = Diagonal()(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(-1)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1,0,2)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
# Shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论