提交 ee776388 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new test for mixture dtype outputs of scan

上级 12ff2b84
......@@ -504,6 +504,64 @@ class T_Scan(unittest.TestCase):
assert not any([isinstance(node.op, theano.sandbox.cuda.GpuFromHost)
for node in scan_node_topo])
# This third test checks that scan can deal with a mixture of dtypes as
# outputs when is running on GPU
def test_gpu3_mixture_dtype_outputs(self):
from theano.sandbox import cuda
if cuda.cuda_available == False:
raise SkipTest('Optional package cuda disabled')
def f_rnn(u_t, x_tm1, W_in, W):
return (u_t * W_in + x_tm1 * W,
tensor.cast(u_t+x_tm1, 'int64'))
u = theano.tensor.fvector('u')
x0 = theano.tensor.fscalar('x0')
W_in = theano.tensor.fscalar('win')
W = theano.tensor.fscalar('w')
output, updates = theano.scan(f_rnn,
u,
[x0, None],
[W_in, W],
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
mode=mode_with_gpu)
f2 = theano.function([u, x0, W_in, W],
output,
updates=updates,
allow_input_downcast=True,
mode=mode_with_gpu)
# get random initial values
rng = numpy.random.RandomState(utt.fetch_seed())
v_u = rng.uniform(size=(4,), low=-5., high=5.)
v_x0 = rng.uniform()
W = rng.uniform()
W_in = rng.uniform()
# compute the output in numpy
v_out1 = numpy.zeros((4,))
v_out2 = numpy.zeros((4,), dtype='int64')
v_out1[0] = v_u[0] * W_in + v_x0 * W
v_out2[0] = v_u[0] + v_x0
for step in xrange(1, 4):
v_out1[step] = v_u[step] * W_in + v_out1[step - 1] * W
v_out2[step] = numpy.int64(v_u[step] + v_out1[step - 1])
theano_out1, theano_out2 = f2(v_u, v_x0, W_in, W)
assert numpy.allclose(theano_out1, v_out1)
assert numpy.allclose(theano_out2, v_out2)
topo = f2.maker.fgraph.toposort()
scan_node = [node for node in topo
if isinstance(node.op, theano.scan_module.scan_op.Scan)]
assert len(scan_node) == 1
scan_node = scan_node[0]
assert scan_node.op.gpu
# simple rnn, one input, one state, weights for each; input/state
# are vectors, weights are scalars; using shared variables
def test_one_sequence_one_output_weights_shared(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论