提交 ca068af0 authored 作者: abalkin's avatar abalkin

More tests for Diagonal.

上级 726aeae5
......@@ -6603,6 +6603,12 @@ class TestInferShape(utt.InferShapeTester):
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1,0,2)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1,1,2)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1,2,0)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
......@@ -7157,6 +7163,9 @@ class TestTensorInstanceMethods(unittest.TestCase):
assert_array_equal(X.diagonal().eval({X: x}), x.diagonal())
assert_array_equal(X.diagonal(1).eval({X: x}), x.diagonal(1))
assert_array_equal(X.diagonal(-1).eval({X: x}), x.diagonal(-1))
for offset, axis1, axis2 in [(1,0,1), (-1,0,1), (0,1,0), (-2,1,0)]:
assert_array_equal(X.diagonal(offset, axis1, axis2).eval({X: x}),
x.diagonal(offset, axis1, axis2))
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论