提交 1518d595 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

syncing the dirty copy of tests : test_scan2.py

上级 86a481ba
...@@ -1808,6 +1808,27 @@ class T_Scan(object): ...@@ -1808,6 +1808,27 @@ class T_Scan(object):
assert len(f2.maker.env.outputs) == 5 assert len(f2.maker.env.outputs) == 5
def test_remove_stuff(self):
x = theano.tensor.vector()
def lm(m):
trng = theano.tensor.shared_randomstreams.RandomStreams(
utt.fetch_seed())
return [ 2*m+ trng.uniform(low =-1.1, high =1.1,
dtype = theano.config.floatX),
m + trng.uniform(size=[3])]
[o1,o2], updates = theano.scan( lm,
sequences = x,
n_steps = None,
truncate_gradient = -1,
go_backwards = False)
go1 = theano.tensor.grad(o1.mean(), wrt = x)
f = theano.function([x],o1, updates = updates,
allow_input_downcast = True)
theano.printing.pydotprint(f, 'ff.png', high_contrast=True)
print f([1,2,3,4,5])
if __name__ == '__main__': if __name__ == '__main__':
''' '''
print ' Use nosetests to run these tests ' print ' Use nosetests to run these tests '
...@@ -1881,8 +1902,8 @@ if __name__ == '__main__': ...@@ -1881,8 +1902,8 @@ if __name__ == '__main__':
print 19 print 19
scan_tst.test_grad_multiple_outs_taps_backwards() scan_tst.test_grad_multiple_outs_taps_backwards()
#''' #'''
print 20 #print 19.5
scan_tst.test_grad_multiple_outs_some_uncomputable() #scan_tst.test_remove_stuff()
#''' #'''
print 21 print 21
scan_tst.test_grad_multiple_outs_some_truncate() scan_tst.test_grad_multiple_outs_some_truncate()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论