提交 eed7ee4b authored 作者: James Bergstra's avatar James Bergstra

added test_scan test_speed_batchrnn for Ilya

上级 faef4935
...@@ -2569,6 +2569,67 @@ def test_speed_rnn(): ...@@ -2569,6 +2569,67 @@ def test_speed_rnn():
print 'theano (updates, cvm)', t3 - t2 print 'theano (updates, cvm)', t3 - t2
print shared_r.get_value() print shared_r.get_value()
def test_speed_batchrnn():
#
# This function prints out the speed of recurrent neural network
# calculations implemented in various ways. In DebugMode this will test the
# correctness of the optimizations applied, but generally
# correctness-testing is not the goal of this test.
#
# To be honest, it isn't really a unit test so much as a tool for testing
# approaches to scan.
#
# The computation being tested here is a repeated tanh of a matrix-vector
# multiplication - the heart of an ESN or RNN.
#
import theano.scalar.sharedvar
print """Warning: the updates version runs slower than python because by
default the blas optimizations don't replace dot with dot22. Why is that?"""
L = 100
B = 50
N = 400
numpy.random.seed(2523452)
r = numpy.arange(B*L*N).astype(theano.config.floatX).reshape(L,B,N)
w = numpy.random.randn(N,N).astype(theano.config.floatX)
t0 = time.time()
for i in xrange(1,L):
r[i] = numpy.tanh(numpy.dot(r[i-1], w))
t1 = time.time()
print 'python', t1 - t0
if 1:
r = numpy.arange(B*L*N).astype(theano.config.floatX).reshape(L,B,N)
s_w = theano.shared(w)
shared_r = theano.shared(r)
s_i = theano.scalar.sharedvar.shared(1)
s_rinc = tensor.inc_subtensor(
shared_r[s_i],
theano.tensor.tanh(
theano.tensor.dot(
shared_r[s_i-1],
w)),
tolerate_inplace_aliasing=True)
f = theano.function([], [],
updates={
s_i: s_i+1,
shared_r: s_rinc,
},
mode=theano.Mode(linker='cvm'))
theano.printing.debugprint(f )
f_fn = f.fn
print f_fn
t2 = time.time()
f_fn(n_calls=L-2)
f() #999 to update the profiling timers
t3 = time.time()
print 'theano (updates, cvm)', t3 - t2
if __name__ == '__main__': if __name__ == '__main__':
#''' #'''
print ' Use nosetests to run these tests ' print ' Use nosetests to run these tests '
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论