提交 649b9361 authored 作者: Yikang Shen's avatar Yikang Shen

test MRG random stream with device=cuda and target=cpu, shared variable +…

test MRG random stream with device=cuda and target=cpu, shared variable + computation in a real model stay on CPU. #6382
上级 404cea07
......@@ -14,7 +14,7 @@ from theano.tensor import as_tensor_variable, get_vector_length
from theano.scalar import int32 as int_t
from .basic_ops import (GpuKernelBase, Kernel, infer_context_name,
host_from_gpu, as_gpuarray_variable)
GpuFromHost, host_from_gpu, as_gpuarray_variable)
from .type import GpuArrayType, gpu_context_type
from .fp16_help import write_w
from .opt import register_opt, register_opt2
......@@ -309,7 +309,9 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
@register_opt2([mrg_uniform], 'fast_compile')
def local_gpua_mrg_graph(op, context_name, inputs, outputs):
if (type(op) == mrg_uniform and
isinstance(inputs[0].type, GpuArrayType)):
isinstance(inputs[0].type, GpuArrayType) and
not isinstance(inputs[0].owner.op, GpuFromHost)
):
outs = GPUA_mrg_uniform.new(inputs[0],
op.output_type.ndim,
op.output_type.dtype,
......
......@@ -11,7 +11,7 @@ from six.moves import xrange
import theano
from theano import change_flags, config, tensor
from theano.sandbox import rng_mrg
from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.sandbox.rng_mrg import MRG_RandomStreams, mrg_uniform
from theano.tests import unittest_tools as utt
from theano.tests.unittest_tools import attr
......@@ -760,6 +760,21 @@ def test_target_parameter():
basic_target_parameter_test(srng.multinomial_wo_replacement(pvals=pvals.astype('float32'), target='cpu'))
def test_cpu_target_with_shared_variable():
srng = MRG_RandomStreams()
x = theano.shared(np.random.rand(2,3).astype('float32'), name='x')
y = srng.uniform(x.shape, target='cpu')
y.name = 'y'
z = (x * y).sum()
z.name = 'z'
fz = theano.function([], z)
nodes = fz.maker.fgraph.toposort()
assert any([isinstance(node.op, mrg_uniform) for node in nodes])
if __name__ == "__main__":
rng = MRG_RandomStreams(np.random.randint(2147462579))
print(theano.__file__)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论