提交 84b61a6c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Tests for MRG float16

上级 5647b421
......@@ -9,6 +9,7 @@ from theano.configparser import change_flags
from theano.sandbox import rng_mrg
from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.sandbox.tests.test_rng_mrg import java_samples, rng_mrg_overflow
from theano.sandbox.tests.test_rng_mrg import test_f16_nonzero as cpu_f16_nonzero
from theano.tests import unittest_tools as utt
from .config import mode_with_gpu as mode
......@@ -162,3 +163,7 @@ def test_validate_input_types_gpuarray_backend():
rstate = np.zeros((7, 6), dtype="int32")
rstate = gpuarray_shared_constructor(rstate)
rng_mrg.mrg_uniform.new(rstate, ndim=None, dtype="float32", size=(3,))
def test_f16_nonzero():
cpu_f16_nonzero(mode=mode, op_to_check=GPUA_mrg_uniform)
......@@ -751,6 +751,16 @@ def test_undefined_grad():
(avg, std))
def test_f16_nonzero(mode=None, op_to_check=rng_mrg.mrg_uniform):
srng = MRG_RandomStreams(seed=utt.fetch_seed())
m = srng.uniform(size=(1000, 1000), dtype='float16')
assert m.dtype == 'float16', m.type
f = theano.function([], m, mode=mode)
assert any(isinstance(n.op, op_to_check) for n in f.maker.fgraph.apply_nodes)
m_val = f()
assert np.all((0 < m_val) & (m_val < 1))
if __name__ == "__main__":
rng = MRG_RandomStreams(np.random.randint(2147462579))
print(theano.__file__)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论