提交 ff8e99ed authored 作者: Reyhane Askari's avatar Reyhane Askari

fix for pygpu

上级 65465106
...@@ -13,7 +13,6 @@ from itertools import chain ...@@ -13,7 +13,6 @@ from itertools import chain
import time import time
import warnings import warnings
import numpy as np import numpy as np
import pygpu
import theano import theano
from theano import config, gof from theano import config, gof
...@@ -1011,8 +1010,10 @@ class Function(object): ...@@ -1011,8 +1010,10 @@ class Function(object):
for i, inp in enumerate(self.input_storage): for i, inp in enumerate(self.input_storage):
if i in self.maker.fgraph.update_mapping.values(): if i in self.maker.fgraph.update_mapping.values():
if (hasattr(theano, "gpuarray") and if (hasattr(theano, "gpuarray") and
isinstance(inp.data, pygpu.gpuarray.GpuArray)): theano.gpuarray.pygpu_activated):
inp.data.sync() import pygpu
if isinstance(inp.data, pygpu.gpuarray.GpuArray):
inp.data.sync()
# pickling/deepcopy support for Function # pickling/deepcopy support for Function
......
...@@ -909,7 +909,7 @@ def test_empty_givens_updates(): ...@@ -909,7 +909,7 @@ def test_empty_givens_updates():
def test_sync(): def test_sync():
if theano.config.device == 'cuda': if theano.config.device == 'cuda' and theano.gpuarray.pygpu_activated:
x = T.fmatrix('x') x = T.fmatrix('x')
w = theano.shared(np.random.rand(300, 500).astype('float32'), 'w') w = theano.shared(np.random.rand(300, 500).astype('float32'), 'w')
b = theano.shared(np.zeros((500)).astype('float32'), 'b') b = theano.shared(np.zeros((500)).astype('float32'), 'b')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论