提交 fdeda4f8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed a bug reported by cityhall consisting in not calling tensor.as_tensor on…

Fixed a bug reported by cityhall consisting in not calling tensor.as_tensor on arguments before calling make_node
上级 79ebf621
......@@ -44,7 +44,7 @@ import tensor
import misc.safe_asarray as safe_asarray
from tensor import opt, TensorType
import gof
from gof import Optimizer, toolbox, Op, Apply
from gof import Optimizer, toolbox, Op, Apply, Variable
from compile import optdb, SharedVariable, function, Param
import compile
import gradient
......@@ -1559,8 +1559,15 @@ class Scan(Op):
theano.config.floatX))
inner_gfn_ins = inner_g_outs + self.inputs
g_args = [self.n_steps] + g_outs[:self.n_outs_not_shared] \
+ scan_outputs + args[1:]
# Make sure you don't have numbers in here
if not isinstance(self.n_steps, Variable):
n_steps = tensor.as_tensor(self.n_steps)
else:
n_steps = self.n_steps
g_args = [n_steps] + g_outs[:self.n_outs_not_shared] \
+ scan_outputs + args[1:]
truncate_gradient = self.truncate_gradient
for x in self.store_steps[:self.n_outs_not_shared]:
if x>0 :
......@@ -1571,6 +1578,7 @@ class Scan(Op):
self.n_seqs, self.n_outs, self.n_outs_not_shared,
self.go_backwards, self.seqs_taps, self.outs_taps,
truncate_gradient)
g_scan_outs = g_scan(g_args)
if not type(g_scan_outs) in (list, tuple):
g_scan_outs = [ g_scan_outs ]
......
......@@ -1039,6 +1039,34 @@ class T_Scan(unittest.TestCase):
assert updates[b].type.ndim == b.type.ndim
def test_scan_as_tensor_on_gradients(self):
    """
    Regression test for a bug reported by cityhall: scan failed when
    computing gradients because its step count was not converted to a
    tensor before make_node was called.
    """
    # Symbolic inputs: initial state vector, scanned matrix, scalar factor.
    init_state = theano.tensor.dvector('to_scan')
    sequence = theano.tensor.dmatrix('seq')
    factor = theano.tensor.dscalar('f1')

    # One scan step: accumulate each row of the sequence scaled by the factor.
    def accumulate(prev, row, scale):
        return prev + scale * row

    scanned, _ = theano.scan(fn=accumulate,
                             sequences=[sequence],
                             outputs_info=[init_state],
                             non_sequences=[factor])

    f_scan = theano.function(inputs=[init_state, sequence, factor],
                             outputs=scanned)
    f_scan([1, 2, 3], numpy.arange(12).reshape([4, 3]), 1.)

    # Gradient w.r.t. the initial state and the factor; the sequence is
    # held constant. Compiling/evaluating this used to trigger the bug.
    t_grad = theano.tensor.grad(scanned.sum(), wrt=[init_state, factor],
                                consider_constant=[sequence])
    f_grad = theano.function(inputs=[init_state, sequence, factor],
                             outputs=t_grad)
    f_scan([1, 2, 3], numpy.arange(12).reshape([4, 3]), 1.)
    f_grad([1, 2, 3], numpy.arange(12).reshape([4, 3]), 1.)
# Run this module's test cases when executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论