提交 a9c0e1ac authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make the scan tests pass

上级 32a13f0e
...@@ -6,10 +6,10 @@ import theano ...@@ -6,10 +6,10 @@ import theano
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
import theano.sandbox.rng_mrg import theano.sandbox.rng_mrg
from ..basic_ops import gpu_from_host, GpuFromHost, HostFromGpu from ..basic_ops import GpuFromHost, HostFromGpu
from ..elemwise import GpuElemwise from ..elemwise import GpuElemwise
from .config import mode_with_gpu from .config import mode_with_gpu, test_ctx_name
class T_Scan(TestCase): class T_Scan(TestCase):
...@@ -35,7 +35,7 @@ class T_Scan(TestCase): ...@@ -35,7 +35,7 @@ class T_Scan(TestCase):
go_backwards=False, go_backwards=False,
mode=mode) mode=mode)
output = gpu_from_host(output) output = GpuFromHost(test_ctx_name)(output)
f2 = theano.function([u, x0, W_in, W], f2 = theano.function([u, x0, W_in, W],
output, output,
updates=updates, updates=updates,
......
...@@ -4854,6 +4854,12 @@ class T_Scan_Gpuarray(unittest.TestCase, ScanGpuTests): ...@@ -4854,6 +4854,12 @@ class T_Scan_Gpuarray(unittest.TestCase, ScanGpuTests):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
from theano.sandbox import gpuarray from theano.sandbox import gpuarray
self.gpu_backend = gpuarray self.gpu_backend = gpuarray
# This is unfortunate, but required
def gpu_from_host(v):
return gpuarray.GpuFromHost(None)(v)
self.gpu_backend.gpu_from_host = gpu_from_host
self.mode_with_gpu = mode_with_opt.including('gpuarray', 'scan') self.mode_with_gpu = mode_with_opt.including('gpuarray', 'scan')
self.mode_with_gpu_nodebug = mode_nodebug.including('gpuarray', 'scan') self.mode_with_gpu_nodebug = mode_nodebug.including('gpuarray', 'scan')
super(T_Scan_Gpuarray, self).__init__(*args, **kwargs) super(T_Scan_Gpuarray, self).__init__(*args, **kwargs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论