提交 4575fbde authored 作者: abalkin's avatar abalkin

WIP: Implemented diagonal() for ndim=2 case.

上级 9da50f5d
......@@ -1637,6 +1637,9 @@ class _tensor_py_operators:
def ravel(self):
return flatten(self)
def diagonal(self):
return diagonal(self)
# CASTING
def astype(self, dtype):
return cast(self, dtype)
......@@ -7180,3 +7183,34 @@ def all(x, axis=None, keepdims=False):
if keepdims:
out = makeKeepDims(x, out, axis)
return out
class Diagonal(Op):
"""Return specified diagonals.
:param x: A tensor variable with x.ndim >= 2.
:return: A vector representing the diagonal elements.
"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x):
return Apply(self, [x], [tensor(dtype=x.dtype,
broadcastable=[False] * (x.ndim -1))])
def perform(self, node, (x,), (z,)):
z[0] = x.diagonal()
def grad(self, (x,), (gz,)):
return [square_diagonal(gz)]
def infer_shape(self, nodes, shapes):
return [(minimum(*shapes[0]), )]
def __str__(self):
return self.__class__.__name__
diagonal = Diagonal()
......@@ -7081,6 +7081,11 @@ class TestTensorInstanceMethods(unittest.TestCase):
x, _ = self.vals
assert_array_equal(X.ravel().eval({X: x}), x.ravel())
def test_diagonal(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.diagonal().eval({X: x}), x.diagonal())
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论