提交 2783b9cc authored 作者: abalkin's avatar abalkin

Reuse optimised special case code in diagonal().

上级 ca068af0
......@@ -7285,7 +7285,10 @@ class Diagonal(Op):
return self.__class__.__name__
def diagonal(a, offset=0, axis1=0, axis2=1):
return Diagonal(offset, axis1, axis2)(a)
if (offset, axis1, axis2) == (0, 0, 1):
from theano.sandbox.linalg import extract_diag
return extract_diag(a)
return Diagonal(offset, axis1, axis2)(a)
class Diag(Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论