提交 6e60fe28 authored 作者: abalkin's avatar abalkin

WIP: Implemented non-zero offset case in Diagonal.infer_shape().

上级 66ba51fe
...@@ -7216,12 +7216,21 @@ class Diagonal(Op): ...@@ -7216,12 +7216,21 @@ class Diagonal(Op):
return [square_diagonal(gz)] return [square_diagonal(gz)]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
xdims = list(shapes[0]) in_shape, = shapes
d0 = minimum(xdims[self.axis1], xdims[self.axis2]) dim1 = in_shape[self.axis1]
xdims = [d for i,d in enumerate(shapes[0]) dim2 = in_shape[self.axis2]
out_shape = [d for i,d in enumerate(in_shape)
if i not in (self.axis1, self.axis2)] if i not in (self.axis1, self.axis2)]
xdims.append(d0) # The following logic is inspired by C code of PyArray_Diagonal().
return [tuple(xdims)] offset = self.offset
if offset > 0:
diag_size = clip(dim2 - offset, 0, dim1)
elif offset < 0:
diag_size = clip(dim1 + offset, 0, dim2)
else:
diag_size = minimum(dim1, dim2)
out_shape.append(diag_size)
return [tuple(out_shape)]
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
......
...@@ -6546,6 +6546,15 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6546,6 +6546,15 @@ class TestInferShape(utt.InferShapeTester):
atens3 = tensor3() atens3 = tensor3()
atens3_val = rand(4, 5, 3) atens3_val = rand(4, 5, 3)
atens3_diag = Diagonal()(atens3) atens3_diag = Diagonal()(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(-1)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1,0,2)(atens3)
self._compile_and_check([atens3], [atens3_diag], self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal) [atens3_val], Diagonal)
# Shape # Shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论