提交 a9c0e1ac authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make the scan tests pass

上级 32a13f0e
......@@ -6,10 +6,10 @@ import theano
from theano.tests import unittest_tools as utt
import theano.sandbox.rng_mrg
from ..basic_ops import gpu_from_host, GpuFromHost, HostFromGpu
from ..basic_ops import GpuFromHost, HostFromGpu
from ..elemwise import GpuElemwise
from .config import mode_with_gpu
from .config import mode_with_gpu, test_ctx_name
class T_Scan(TestCase):
......@@ -35,7 +35,7 @@ class T_Scan(TestCase):
go_backwards=False,
mode=mode)
output = gpu_from_host(output)
output = GpuFromHost(test_ctx_name)(output)
f2 = theano.function([u, x0, W_in, W],
output,
updates=updates,
......
......@@ -4854,6 +4854,12 @@ class T_Scan_Gpuarray(unittest.TestCase, ScanGpuTests):
def __init__(self, *args, **kwargs):
from theano.sandbox import gpuarray
self.gpu_backend = gpuarray
# This is unfortunate, but required
def gpu_from_host(v):
return gpuarray.GpuFromHost(None)(v)
self.gpu_backend.gpu_from_host = gpu_from_host
self.mode_with_gpu = mode_with_opt.including('gpuarray', 'scan')
self.mode_with_gpu_nodebug = mode_nodebug.including('gpuarray', 'scan')
super(T_Scan_Gpuarray, self).__init__(*args, **kwargs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论