提交 13b199a3 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Hopefully final fix to scan with GPU (plus test)

上级 6828831c
......@@ -1214,7 +1214,12 @@ def gpuScanOptimization(node):
, info
, typeConstructor = typeConstructor
).make_node(*nw_ins).outputs
outputs = [safe_to_cpu(x) for x in _outputs]
outputs = []
for x,y in zip(_outputs, node.outputs):
if isinstance(y.type, CudaNdarrayType):
outputs += [x]
else:
outputs += [safe_to_cpu(x)]
return outputs
return False
......
......@@ -92,7 +92,7 @@ def traverse(out, x,x_copy, d):
inner graph of scan '''
import theano.sandbox.cuda as cuda
if out == x:
d[out] = tensor.as_tensor_variable(x_copy)
d[out] = cuda.gpu_from_host(x_copy)
return d
elif out.owner is None:
return d
......
......@@ -906,6 +906,50 @@ class T_Scan(unittest.TestCase):
theano_v = my_f()
assert numpy.allclose( theano_v , numpy_v[5:,:])
def test_cuda_gibbs_chain(self):
    """Check that a scan-based sampling chain compiles and runs with GPU opts.

    A float32 shared variable of shape (3, 20) holding binomial samples is
    iterated for 10 scan steps; every step multiplies the previous state by
    a fresh binomial mask drawn from ``MRG_RandomStreams``. The gradient of
    the last state with respect to the shared variable is built, and the
    function returning the last state is compiled and called once.

    No numerical result is asserted: the test only verifies that the graph
    can be built, compiled and executed.

    Raises:
        SkipTest: if ``cuda.cuda_available`` is false.
    """
    import theano
    from nose.plugins.skip import SkipTest
    from theano.sandbox import cuda
    if not cuda.cuda_available:
        raise SkipTest('Optional package cuda disabled')

    # FAST_COMPILE is swapped for FAST_RUN before the 'gpu' and 'scan'
    # optimizations are included in the mode.
    if theano.config.mode == 'FAST_COMPILE':
        mode = theano.compile.mode.get_mode('FAST_RUN')
    else:
        mode = theano.compile.mode.get_default_mode()
    mode = mode.including('gpu', 'scan')

    rng = numpy.random.RandomState(utt.fetch_seed())
    v_vsample = numpy.array(rng.binomial(1, 0.5, size=(3, 20)),
                            dtype='float32')
    vsample = theano.shared(v_vsample)

    import theano.sandbox.rng_mrg
    trng = theano.sandbox.rng_mrg.MRG_RandomStreams(utt.fetch_seed())

    def f(vsample_tm1):
        # One chain step: mask the previous sample with a binomial draw.
        return trng.binomial(vsample_tm1.shape, n=1, p=0.3,
                             dtype='float32') * vsample_tm1

    theano_vsamples, updates = theano.scan(f, [], vsample, [],
                                           n_steps=10,
                                           truncate_gradient=-1,
                                           go_backwards=False,
                                           mode=mode)

    # The gradient is never evaluated; it is built only so that graph
    # construction through the scan is exercised as well.
    gout = theano.tensor.grad(theano_vsamples[-1].sum(), vsample)

    my_f = theano.function([], theano_vsamples[-1],
                           updates=updates,
                           allow_input_downcast=True,
                           mode=mode)

    # Value checking is left to DebugMode; this test is mostly a
    # "does the graph compile and run" kind of test.
    my_f()
def test_gibbs_chain(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论