提交 6a087050 authored 作者: Reyhane Askari's avatar Reyhane Askari

test_sync added

上级 616f8d37
......@@ -3,6 +3,7 @@ import copy
import six.moves.cPickle as pickle
import numpy as np
import unittest
import time
from theano import config, gof
......@@ -907,6 +908,31 @@ def test_empty_givens_updates():
function([theano.In(x)], y, updates={})
def test_sync():
x = T.fmatrix('x')
w = theano.shared(np.random.rand(300, 500).astype('float32'), 'w')
b = theano.shared(np.zeros((500)).astype('float32'), 'b')
y = T.dot(x, w) + b.dimshuffle('x', 0)
updates = [(w, w + T.sum(T.dot(x, w) +
T.dot(5 * x, 2 * w)))]
f = theano.function([x], y, updates=updates, sync=True)
g = theano.function([x], y, updates=updates, sync=False)
x_ = np.random.rand(100, 300).astype('float32')
f(x_)
g(x_)
t_0 = time.time()
for i in range(1000):
f(x_)
t_1 = time.time()
for i in range(1000):
g(x_)
t_2 = time.time()
assert (t_1 - t_0) > (t_2 - t_1)
if __name__ == '__main__':
if 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论