提交 971e1302 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed my update to scan grad

上级 d1e926cb
...@@ -1374,11 +1374,21 @@ class Scan(PureOp): ...@@ -1374,11 +1374,21 @@ class Scan(PureOp):
self.inner_nitsot_outs(self_outputs)) self.inner_nitsot_outs(self_outputs))
def compute_gradient(y, g_y): def compute_gradient(y, g_y):
gmp = gradient.grad_sources_inputs( if 'int' in str(g_y.dtype):
[(y, g_y)], raise TypeError("Gradients may never be integers but g_y "
[x for x in theano.gof.graph.inputs([y]) "has type "+str(g_y.type))
if x in diff_inputs])
return [gmp.get(p, None) for p in diff_inputs] wrt = [x for x in theano.gof.graph.inputs([y])
if x in diff_inputs]
grads = gradient.grad(
cost = None,
known_grads = {y : g_y },
wrt=wrt, consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None')
gmp = dict(zip(wrt, grads))
rval = [gmp.get(p, None) for p in diff_inputs]
return rval
dC_dinps_t = [None for inp in diff_inputs] dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs] disconnected_dC_dinps_t = [True for inp in diff_inputs]
dC_dXts = [] dC_dXts = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论