提交 6c67891a authored 作者: abalkin's avatar abalkin

WIP: Implemented diagonal(x, offset=0, axis1=0, axis2=1).

上级 4575fbde
......@@ -1637,8 +1637,8 @@ class _tensor_py_operators:
def ravel(self):
return flatten(self)
def diagonal(self):
return diagonal(self)
def diagonal(self, offset=0, axis1=0, axis2=1):
return diagonal(self, offset, axis1, axis2)
# CASTING
def astype(self, dtype):
......@@ -7191,19 +7191,22 @@ class Diagonal(Op):
:return: A vector representing the diagonal elements.
"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x):
return Apply(self, [x], [tensor(dtype=x.dtype,
broadcastable=[False] * (x.ndim -1))])
def make_node(self, x, offset=0, axis1=0, axis2=1):
x = as_tensor_variable(x)
assert x.ndim >= 2
offset, axis1, axis2 = map(scal.as_scalar, (offset, axis1, axis2))
return Apply(self, [x, offset, axis1, axis2], [tensor(dtype=x.dtype,
broadcastable=[False] * (x.ndim -1))])
def perform(self, node, (x,), (z,)):
z[0] = x.diagonal()
def perform(self, node, (x, off, ax1, ax2), (z,)):
z[0] = x.diagonal(off, ax1, ax2)
def grad(self, (x,), (gz,)):
return [square_diagonal(gz)]
......
......@@ -7085,6 +7085,8 @@ class TestTensorInstanceMethods(unittest.TestCase):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.diagonal().eval({X: x}), x.diagonal())
assert_array_equal(X.diagonal(1).eval({X: x}), x.diagonal(1))
assert_array_equal(X.diagonal(-1).eval({X: x}), x.diagonal(-1))
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论