提交 40381bca authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made scan treat DisconnectedType as 0

上级 f53a585d
......@@ -30,6 +30,7 @@ from theano.tensor import TensorType
from theano import tensor
from theano.tensor.opt import Shape_i
from theano.gradient import grad_undefined
from theano.gradient import DisconnectedType
#from theano.sandbox import cuda
from theano.compile.profiling import ScanProfileStats
......@@ -1162,6 +1163,18 @@ class Scan(PureOp):
### GRAD FUNCTION
def grad(self, args, g_outs):
#This discards information about whether incoming gradients are 0
#or disconnected from the cost
#TODO: upgrade scan op to report disconnection correctly
def strip_disconnected( g ):
if isinstance(g.type, DisconnectedType):
return None
return g
g_outs = [ strip_disconnected(g) for g in g_outs ]
# 1. forward pass - get the outputs after applying scan
scan_outputs = self(*args)
# 2. make sure they are given as a list
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论