提交 3379fb05 authored 作者: Reyhane Askari's avatar Reyhane Askari

minor fixes

上级 ff8e99ed
...@@ -1007,13 +1007,13 @@ class Function(object): ...@@ -1007,13 +1007,13 @@ class Function(object):
return [i.variable for i in self.maker.inputs if i.implicit] return [i.variable for i in self.maker.inputs if i.implicit]
def sync_shared(self):
    """Wait for pending GPU transfers on updated shared variables.

    For each entry in ``self.input_storage`` whose index appears in
    ``self.maker.fgraph.update_mapping.values()`` (i.e. storage that
    receives an update), call ``.sync()`` on its data when that data is
    a ``pygpu.gpuarray.GpuArray``, blocking until the asynchronous GPU
    computation backing it has finished.

    No-op when the gpuarray backend is absent or not activated.
    """
    # Hoist the backend check and the pygpu import out of the loop:
    # they are loop-invariant, so testing them once is enough.
    if (hasattr(theano, "gpuarray") and
            theano.gpuarray.pygpu_activated):
        import pygpu
        for i, inp in enumerate(self.input_storage):
            # Only storage cells that are targets of an update can hold
            # results of still-running GPU kernels.
            if i in self.maker.fgraph.update_mapping.values():
                if isinstance(inp.data, pygpu.gpuarray.GpuArray):
                    inp.data.sync()
# pickling/deepcopy support for Function # pickling/deepcopy support for Function
......
...@@ -911,18 +911,17 @@ def test_empty_givens_updates(): ...@@ -911,18 +911,17 @@ def test_empty_givens_updates():
def test_sync(): def test_sync():
if theano.config.device == 'cuda' and theano.gpuarray.pygpu_activated: if theano.config.device == 'cuda' and theano.gpuarray.pygpu_activated:
x = T.fmatrix('x') x = T.fmatrix('x')
w = theano.shared(np.random.rand(300, 500).astype('float32'), 'w') w = theano.shared(np.random.rand(2000, 2000).astype('float32'), 'w')
b = theano.shared(np.zeros((500)).astype('float32'), 'b') b = theano.shared(np.zeros((2000)).astype('float32'), 'b')
y = T.dot(x, w) + b.dimshuffle('x', 0) y = T.dot(x, x) + b.dimshuffle('x', 0)
updates = [(w, w + T.sum(T.dot(x, w) + updates = [(w, w + T.dot(w, x) + T.dot(w, w))]
T.dot(5 * x, 2 * w)))]
f = theano.function([x], y, updates=updates) f = theano.function([x], y, updates=updates)
f.sync_shared() f.sync_shared()
g = theano.function([x], y, updates=updates) g = theano.function([x], y, updates=updates)
x_ = np.random.rand(100, 300).astype('float32') x_ = np.random.rand(2000, 2000).astype('float32')
f(x_) f(x_)
g(x_) g(x_)
t_0 = time.time() t_0 = time.time()
...@@ -933,6 +932,8 @@ def test_sync(): ...@@ -933,6 +932,8 @@ def test_sync():
g(x_) g(x_)
t_2 = time.time() t_2 = time.time()
assert (t_1 - t_0) > (t_2 - t_1) assert (t_1 - t_0) > (t_2 - t_1)
else:
raise SkipTest("Sync is only availble when device is cuda.")
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论