提交 38d12f58 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fix bug in gradient wrt to mitsot states

上级 9517a606
...@@ -1289,7 +1289,7 @@ class Scan(PureOp): ...@@ -1289,7 +1289,7 @@ class Scan(PureOp):
for idx in xrange(self.n_mit_mot + self.n_mit_sot): for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = numpy.min(self.tap_array[idx]) mintap = numpy.min(self.tap_array[idx])
maxtap = numpy.max(self.tap_array[idx]) maxtap = numpy.max(self.tap_array[idx])
seq = scan_outputs[offset+idx][::-1] seq = scan_outputs[offset+idx]
for k in self.tap_array[idx]: for k in self.tap_array[idx]:
# We cut the sequence such that seq[i] to correspond to # We cut the sequence such that seq[i] to correspond to
# seq[i-k] # seq[i-k]
...@@ -1300,9 +1300,9 @@ class Scan(PureOp): ...@@ -1300,9 +1300,9 @@ class Scan(PureOp):
if maxtap == mintap and maxtap != 0: if maxtap == mintap and maxtap != 0:
nw_seq =seq[:abs(maxtap)] nw_seq =seq[:abs(maxtap)]
elif maxtap -k != 0 : elif maxtap -k != 0 :
nw_seq = seq[dim_offset +k -mintap: -(maxtap -k)] nw_seq = seq[dim_offset +k -mintap - 1: -(maxtap -k + 1)][::-1]
else: else:
nw_seq = seq[dim_offset +k -mintap: ] nw_seq = seq[dim_offset +k -mintap - 1: -1 ][::-1]
if getattr(seq,'name', None) is not None: if getattr(seq,'name', None) is not None:
nw_seq.name = seq.name + '[%d:]'%k nw_seq.name = seq.name + '[%d:]'%k
scan_seqs.append(nw_seq) scan_seqs.append(nw_seq)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论