提交 d55ec220 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2514 from sebastien-j/faster_tests_2

Speed up scan test
......@@ -4,6 +4,7 @@ import sys
from tempfile import mkdtemp
import time
import unittest
import copy
import cPickle
import numpy
......@@ -1909,20 +1910,25 @@ class T_Scan(unittest.TestCase):
gparams = theano.grad(cost, params)
updates = [(param, param - gparam * learning_rate)
for param, gparam in zip(params, gparams)]
mode = copy.copy(theano.compile.get_default_mode())
mode.check_py_code = False
learn_rnn_fn = theano.function(inputs=[x, t],
outputs=cost,
updates=updates)
updates=updates,
mode=mode)
eval_rnn_fn = theano.function(inputs=[x],
outputs=y)
outputs=y,
mode=mode)
# artificial data
x_v = numpy.arange(0., 100., 0.21, dtype=theano.config.floatX)
x_v = numpy.arange(0., 10.49, 0.21, dtype=theano.config.floatX)
x_v = x_v.reshape(len(x_v), 1)
s_v = numpy.sin(x_v)
t_v = numpy.roll(s_v, -1)[:-1]
s_v = s_v[:-1]
for i in xrange(100):
cost = learn_rnn_fn(s_v, t_v)
print i, cost
pred = eval_rnn_fn(s_v)
assert cost < 0.02
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论