提交 d55ec220 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2514 from sebastien-j/faster_tests_2

Speed up scan test
...@@ -4,6 +4,7 @@ import sys ...@@ -4,6 +4,7 @@ import sys
from tempfile import mkdtemp from tempfile import mkdtemp
import time import time
import unittest import unittest
import copy
import cPickle import cPickle
import numpy import numpy
...@@ -1909,20 +1910,25 @@ class T_Scan(unittest.TestCase): ...@@ -1909,20 +1910,25 @@ class T_Scan(unittest.TestCase):
gparams = theano.grad(cost, params) gparams = theano.grad(cost, params)
updates = [(param, param - gparam * learning_rate) updates = [(param, param - gparam * learning_rate)
for param, gparam in zip(params, gparams)] for param, gparam in zip(params, gparams)]
mode = copy.copy(theano.compile.get_default_mode())
mode.check_py_code = False
learn_rnn_fn = theano.function(inputs=[x, t], learn_rnn_fn = theano.function(inputs=[x, t],
outputs=cost, outputs=cost,
updates=updates) updates=updates,
mode=mode)
eval_rnn_fn = theano.function(inputs=[x], eval_rnn_fn = theano.function(inputs=[x],
outputs=y) outputs=y,
mode=mode)
# artificial data # artificial data
x_v = numpy.arange(0., 100., 0.21, dtype=theano.config.floatX) x_v = numpy.arange(0., 10.49, 0.21, dtype=theano.config.floatX)
x_v = x_v.reshape(len(x_v), 1) x_v = x_v.reshape(len(x_v), 1)
s_v = numpy.sin(x_v) s_v = numpy.sin(x_v)
t_v = numpy.roll(s_v, -1)[:-1] t_v = numpy.roll(s_v, -1)[:-1]
s_v = s_v[:-1] s_v = s_v[:-1]
for i in xrange(100): for i in xrange(100):
cost = learn_rnn_fn(s_v, t_v) cost = learn_rnn_fn(s_v, t_v)
print i, cost
pred = eval_rnn_fn(s_v) pred = eval_rnn_fn(s_v)
assert cost < 0.02 assert cost < 0.02
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论