提交 ff8e99ed authored 作者: Reyhane Askari's avatar Reyhane Askari

fix for pygpu

上级 65465106
......@@ -13,7 +13,6 @@ from itertools import chain
import time
import warnings
import numpy as np
import pygpu
import theano
from theano import config, gof
......@@ -1011,8 +1010,10 @@ class Function(object):
for i, inp in enumerate(self.input_storage):
if i in self.maker.fgraph.update_mapping.values():
if (hasattr(theano, "gpuarray") and
isinstance(inp.data, pygpu.gpuarray.GpuArray)):
inp.data.sync()
theano.gpuarray.pygpu_activated):
import pygpu
if isinstance(inp.data, pygpu.gpuarray.GpuArray):
inp.data.sync()
# pickling/deepcopy support for Function
......
......@@ -909,7 +909,7 @@ def test_empty_givens_updates():
def test_sync():
if theano.config.device == 'cuda':
if theano.config.device == 'cuda' and theano.gpuarray.pygpu_activated:
x = T.fmatrix('x')
w = theano.shared(np.random.rand(300, 500).astype('float32'), 'w')
b = theano.shared(np.zeros((500)).astype('float32'), 'b')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论