提交 fdeda4f8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed a bug reported by cityhall consisting in not calling tensor.as_tensor on arguments before calling make_node

Fixed a bug reported by cityhall consisting in not calling tensor.as_tensor on arguments before calling make_node
上级 79ebf621
...@@ -44,7 +44,7 @@ import tensor ...@@ -44,7 +44,7 @@ import tensor
import misc.safe_asarray as safe_asarray import misc.safe_asarray as safe_asarray
from tensor import opt, TensorType from tensor import opt, TensorType
import gof import gof
from gof import Optimizer, toolbox, Op, Apply from gof import Optimizer, toolbox, Op, Apply, Variable
from compile import optdb, SharedVariable, function, Param from compile import optdb, SharedVariable, function, Param
import compile import compile
import gradient import gradient
...@@ -1559,8 +1559,15 @@ class Scan(Op): ...@@ -1559,8 +1559,15 @@ class Scan(Op):
theano.config.floatX)) theano.config.floatX))
inner_gfn_ins = inner_g_outs + self.inputs inner_gfn_ins = inner_g_outs + self.inputs
g_args = [self.n_steps] + g_outs[:self.n_outs_not_shared] \
+ scan_outputs + args[1:] # Make sure you don't have numbers in here
if not isinstance(self.n_steps, Variable):
n_steps = tensor.as_tensor(self.n_steps)
else:
n_steps = self.n_steps
g_args = [n_steps] + g_outs[:self.n_outs_not_shared] \
+ scan_outputs + args[1:]
truncate_gradient = self.truncate_gradient truncate_gradient = self.truncate_gradient
for x in self.store_steps[:self.n_outs_not_shared]: for x in self.store_steps[:self.n_outs_not_shared]:
if x>0 : if x>0 :
...@@ -1571,6 +1578,7 @@ class Scan(Op): ...@@ -1571,6 +1578,7 @@ class Scan(Op):
self.n_seqs, self.n_outs, self.n_outs_not_shared, self.n_seqs, self.n_outs, self.n_outs_not_shared,
self.go_backwards, self.seqs_taps, self.outs_taps, self.go_backwards, self.seqs_taps, self.outs_taps,
truncate_gradient) truncate_gradient)
g_scan_outs = g_scan(g_args) g_scan_outs = g_scan(g_args)
if not type(g_scan_outs) in (list, tuple): if not type(g_scan_outs) in (list, tuple):
g_scan_outs = [ g_scan_outs ] g_scan_outs = [ g_scan_outs ]
......
...@@ -1039,6 +1039,34 @@ class T_Scan(unittest.TestCase): ...@@ -1039,6 +1039,34 @@ class T_Scan(unittest.TestCase):
assert updates[b].type.ndim == b.type.ndim assert updates[b].type.ndim == b.type.ndim
def test_scan_as_tensor_on_gradients(self):
"""
Bug reported by cityhall on scan when computing the gradients
"""
to_scan = theano.tensor.dvector('to_scan')
seq = theano.tensor.dmatrix('seq')
f1 = theano.tensor.dscalar('f1')
def scanStep(prev, seq, f1):
return prev + f1 * seq
scanned, _ = theano.scan(fn = scanStep, \
sequences = [seq], \
outputs_info = [to_scan], \
non_sequences = [f1])
f_scan = theano.function(inputs=[to_scan, seq, f1], outputs=scanned)
f_scan([1,2,3], numpy.arange(12).reshape([4,3]), 1.)
t_grad = theano.tensor.grad(scanned.sum(), wrt=[to_scan, f1],
consider_constant=[seq])
f_grad = theano.function(inputs=[to_scan, seq, f1], outputs=t_grad)
f_scan([1,2,3], numpy.arange(12).reshape([4,3]), 1.)
f_grad([1,2,3], numpy.arange(12).reshape([4,3]), 1.)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论