提交 783713f1 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier 提交者: --global

In grad(), return DisconnectedType() instead of 0s

上级 993df98e
...@@ -3,6 +3,7 @@ import pdb ...@@ -3,6 +3,7 @@ import pdb
import theano import theano
import theano.tensor as T import theano.tensor as T
from theano.gof import Op, Apply from theano.gof import Op, Apply
from theano.gradient import DisconnectedType
class PdbBreakpoint(Op): class PdbBreakpoint(Op):
""" """
...@@ -90,8 +91,7 @@ class PdbBreakpoint(Op): ...@@ -90,8 +91,7 @@ class PdbBreakpoint(Op):
output_storage[i][0] = monitored[i] output_storage[i][0] = monitored[i]
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
return ([inputs[0].zeros_like().astype(theano.config.floatX)] + return ([DisconnectedType()] + output_gradients)
output_gradients)
def infer_shape(self, inputs, input_shapes): def infer_shape(self, inputs, input_shapes):
# Return the shape of every input but the condition # Return the shape of every input but the condition
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论