提交 92ed3123 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed bug reported by Justtin, when scan would cache the number of steps

from previous runs, which ScanGrad would end up using
上级 6ab9f3dd
...@@ -1179,6 +1179,8 @@ class Scan(Op): ...@@ -1179,6 +1179,8 @@ class Scan(Op):
def make_node(self,*inputs): def make_node(self,*inputs):
assert all(isinstance(i, gof.Variable) for i in inputs) assert all(isinstance(i, gof.Variable) for i in inputs)
self.n_steps = inputs[0]
return Apply(self, inputs, [t() for t in self.apply_output_types]) return Apply(self, inputs, [t() for t in self.apply_output_types])
...@@ -1309,7 +1311,7 @@ class Scan(Op): ...@@ -1309,7 +1311,7 @@ class Scan(Op):
'required by the maximal past value %d. Scan will use 0s' 'required by the maximal past value %d. Scan will use 0s'
' for missing values')%(i-self.n_iterable-1,req_size)) ' for missing values')%(i-self.n_iterable-1,req_size))
self.n_steps = n_steps
y = self.scan(self.fn, args[1:],self.n_seqs, self.n_outs, y = self.scan(self.fn, args[1:],self.n_seqs, self.n_outs,
self.seqs_taps, self.outs_taps, n_steps, go_backwards, self.seqs_taps, self.outs_taps, n_steps, go_backwards,
inplace_map) inplace_map)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论