Commit 5d3e508f authored by Razvan Pascanu

fix for the grad method

Parent 8f087c62
@@ -1198,6 +1198,37 @@ class Scan(PureOp):
                          for o, x in izip(node.outputs, scan_outs)]
         return scan_outs
 
+    def get_input_pos(self, output_index):
+        # Map the index of an outer output to the position of the inner
+        # input that computes it, or -1 if there is no such input.
+        ipos = 0
+        opos = output_index
+        # mitmot outputs are counted per output tap; each mitmot state
+        # consumes len(itaps) inner inputs
+        for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
+            if len(otaps) > opos:
+                return ipos
+            else:
+                opos = opos - len(otaps)
+                ipos += len(itaps)
+        # each mitsot state yields one output and consumes len(taps) inputs
+        for taps in self.mitsot_taps():
+            if opos == 0:
+                return ipos
+            else:
+                opos = opos - 1
+                ipos += len(taps)
+        # sitsot states map one to one
+        if opos < self.info['n_sit_sot']:
+            return ipos + opos
+        else:
+            return -1
+
+    def get_output_slice_idx(self, output_index):
+        # Map the index of an outer output to the index of its first
+        # inner output slice; each mitmot state contributes len(otaps)
+        # slices, every other state exactly one.
+        ipos = 0
+        opos = output_index
+        for otaps in self.mitmot_out_taps():
+            if opos == 0:
+                return ipos
+            else:
+                opos = opos - 1
+                ipos += len(otaps)
+        return ipos + opos
+
     ### GRAD FUNCTION
     def grad(self, args, g_outs):
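As a sanity check on the tap arithmetic above, here is a minimal standalone sketch of the logic in get_input_pos, rewritten against plain lists; the name input_pos and the tap configurations are invented for illustration and are not part of the commit:

# Standalone sketch of the tap-counting logic in get_input_pos, using
# hypothetical tap lists instead of the Scan op's own accessors.
def input_pos(output_index, mitmot_out_taps, mitmot_in_taps,
              mitsot_taps, n_sit_sot):
    ipos, opos = 0, output_index
    # mitmot outputs are counted per output tap; each mitmot state
    # consumes len(itaps) inner inputs
    for otaps, itaps in zip(mitmot_out_taps, mitmot_in_taps):
        if len(otaps) > opos:
            return ipos
        opos -= len(otaps)
        ipos += len(itaps)
    # each mitsot state yields one output and consumes len(taps) inputs
    for taps in mitsot_taps:
        if opos == 0:
            return ipos
        opos -= 1
        ipos += len(taps)
    # sitsot states map one to one; anything past them has no input
    return ipos + opos if opos < n_sit_sot else -1

# Two mitmot states (2 out taps / 3 in taps, then 1 / 2) and one
# mitsot state reading 2 taps:
assert input_pos(0, [[0, 1], [0]], [[-2, -1, 0], [-1, 0]], [[-2, -1]], 1) == 0
assert input_pos(2, [[0, 1], [0]], [[-2, -1, 0], [-1, 0]], [[-2, -1]], 1) == 3
assert input_pos(3, [[0, 1], [0]], [[-2, -1, 0], [-1, 0]], [[-2, -1]], 1) == 5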
@@ -1303,12 +1334,12 @@ class Scan(PureOp):
         # 7.3. compute gradients of the inputs given one output
         for dx, out in enumerate(clean_outputs):
-            if g_outs[dx] != None:
-                inner_g_out = safe_new(g_outs[dx][0])
+            input_pos = self.get_input_pos(dx)
+            if input_pos >= 0:
+                corresponding_input = self_inputs[input_pos]
+                tmp = tensor.grad(out.sum(), corresponding_input)
+                inner_g_out = safe_new(tmp)
             else:
                 # We do not have a gradient on this output so we need a
                 # placeholder, which for now has the same dtype as the
                 # output
                 inner_g_out = safe_new(out)
             ###
             #### I need to clip the gradient HERE !!
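The reason for routing through tensor.grad here is that the resulting variable carries the dtype of the input being differentiated rather than the dtype of the output. A minimal sketch of that dtype behavior, assuming plain Theano tensors outside of scan (the variable names are invented):

import theano.tensor as tensor

x = tensor.dvector('x')              # float64 input
out = tensor.cast(2 * x, 'float32')  # float32 output computed from x
# the gradient of a scalar cost w.r.t. x carries x's dtype, not out's
g = tensor.grad(out.sum(), x)
assert g.dtype == 'float64'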
@@ -1356,7 +1387,9 @@ class Scan(PureOp):
                 # this try is for catching non ndarray inputs (random
                 # states) it is more of a safety check ( all random
                 # states should be after n_outs_not_shared ...
-                g_outs[i] = tensor.zeros_like(scan_outputs[i])
+                g_outs[i] = tensor.cast(
+                    tensor.zeros_like(scan_outputs[i]),
+                    inner_gfn_outs[self.get_output_slice_idx(i)].dtype)
             except Exception:
                 g_outs[i] = theano.tensor.constant(
                     numpy.array(0, theano.config.floatX))
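This hunk is the counterpart of the change above: tensor.zeros_like inherits the dtype of the forward output, which may no longer match the inner gradient's dtype, so an explicit cast is needed. A minimal sketch, with invented variable names standing in for scan_outputs[i] and inner_gfn_outs[...]:

import theano.tensor as tensor

scan_output = tensor.vector('scan_output', dtype='float32')  # forward output
inner_grad = tensor.vector('inner_grad', dtype='float64')    # inner gradient
# zeros_like inherits the forward output's dtype ...
assert tensor.zeros_like(scan_output).dtype == 'float32'
# ... so an explicit cast is what matches it to the inner gradient
g = tensor.cast(tensor.zeros_like(scan_output), inner_grad.dtype)
assert g.dtype == 'float64'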