提交 cdb6f78f authored 作者: abalkin's avatar abalkin 提交者: Frederic

Added a work-around for grad wrt disconnected variables.

上级 9df4514a
...@@ -12,7 +12,7 @@ from theano.tensor.opt import (register_stabilize, ...@@ -12,7 +12,7 @@ from theano.tensor.opt import (register_stabilize,
register_specialize, register_canonicalize) register_specialize, register_canonicalize)
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gradient import grad_not_implemented from theano.gradient import grad_not_implemented, DisconnectedType
1 1
try: try:
import scipy.linalg import scipy.linalg
...@@ -922,6 +922,11 @@ class Eig(Op): ...@@ -922,6 +922,11 @@ class Eig(Op):
eig = Eig() eig = Eig()
def _zero_disconnected(outputs, grads):
return [o.zeros_like()
if isinstance(g.type, DisconnectedType) else g
for o, g in zip(outputs, grads)]
class Eigh(Eig): class Eigh(Eig):
""" """
Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix. Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix.
...@@ -963,7 +968,9 @@ class Eigh(Eig): ...@@ -963,7 +968,9 @@ class Eigh(Eig):
""" """
x, = inputs x, = inputs
w, v = self(x) w, v = self(x)
gw, gv = g_outputs # Replace gradients wrt disconnected variables with
# zeros. This is a work-around for issue #1063.
gw, gv = _zero_disconnected([w, v], g_outputs)
return [EighGrad()(x, w, v, gw, gv)] return [EighGrad()(x, w, v, gw, gv)]
eigh = Eigh() eigh = Eigh()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论