提交 b0a7305b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added a test that fails only in debug mode for unknown reasons right now.

上级 1518d595
...@@ -1826,20 +1826,26 @@ class T_Scan(unittest.TestCase): ...@@ -1826,20 +1826,26 @@ class T_Scan(unittest.TestCase):
def test_remove_stuff(self): def test_remove_stuff(self):
trng = theano.tensor.shared_randomstreams.RandomStreams( x = theano.tensor.vector()
def lm(m):
trng = theano.tensor.shared_randomstreams.RandomStreams(
utt.fetch_seed()) utt.fetch_seed())
return [ 2*m+ trng.uniform(low =-1.1, high =1.1,
dtype = theano.config.floatX),
m + trng.uniform(size=[3])]
x = theano.tensor.vector() [o1,o2], updates = theano.scan( lm,
[o1,o2], updates = theano.scan( lambda m:
[2*m+trng.uniform(),m+trng.uniform()],
sequences = x, sequences = x,
n_steps = None, n_steps = None,
truncate_gradient = -1, truncate_gradient = -1,
go_backwards = False) go_backwards = False)
go1 = theano.tensor.grad(o1.mean(), wrt = x)
f = theano.function([x],go1, updates = updates,
allow_input_downcast = True)
print f([1,2,3])
f = theano.function([x],o1, allow_input_downcast = True)
f([1,2,3,4,5])
if __name__ == '__main__': if __name__ == '__main__':
#''' #'''
......
...@@ -1824,10 +1824,9 @@ class T_Scan(object): ...@@ -1824,10 +1824,9 @@ class T_Scan(object):
go_backwards = False) go_backwards = False)
go1 = theano.tensor.grad(o1.mean(), wrt = x) go1 = theano.tensor.grad(o1.mean(), wrt = x)
f = theano.function([x],o1, updates = updates, f = theano.function([x],go1, updates = updates,
allow_input_downcast = True) allow_input_downcast = True)
theano.printing.pydotprint(f, 'ff.png', high_contrast=True) print f([1,2,3])
print f([1,2,3,4,5])
if __name__ == '__main__': if __name__ == '__main__':
''' '''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论