提交 50f66cc6 authored 作者: Reyhane Askari's avatar Reyhane Askari

added check for gpuarray

上级 6a087050
......@@ -13,6 +13,7 @@ from itertools import chain
import time
import warnings
import numpy as np
import pygpu
import theano
from theano import config, gof
......@@ -1693,6 +1694,8 @@ class FunctionMaker(object):
if self.sync:
for i, inp in enumerate(input_storage):
if i in self.fgraph.update_mapping.values():
if (hasattr(theano, "gpuarray") and
isinstance(inp.data, pygpu.gpuarray.GpuArray)):
inp.data.sync()
fn.profile = self.profile
......
......@@ -909,6 +909,7 @@ def test_empty_givens_updates():
def test_sync():
if theano.config.device == 'cuda':
x = T.fmatrix('x')
w = theano.shared(np.random.rand(300, 500).astype('float32'), 'w')
b = theano.shared(np.zeros((500)).astype('float32'), 'b')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论