提交 c39a88df authored 作者: David Warde-Farley's avatar David Warde-Farley

TST: test CURAND rng when shape is a Variable

Add tests for all the branches involving Variable that weren't tested and thus not catching that Variable was undefined.
上级 d309a022
import numpy
import theano
from theano.tensor import vector
from theano.sandbox.cuda.rng_curand import CURAND_RandomStreams
from theano.sandbox.rng_mrg import MRG_RandomStreams
......@@ -9,14 +10,19 @@ else:
mode_with_gpu = theano.compile.mode.get_default_mode().including('gpu')
def test_uniform_basic():
def check_uniform_basic(shape_as_theano_variable):
rng = CURAND_RandomStreams(234)
u0 = rng.uniform((10, 10))
u1 = rng.uniform((10, 10))
f0 = theano.function([], u0, mode=mode_with_gpu)
f1 = theano.function([], u1, mode=mode_with_gpu)
if shape_as_theano_variable:
shape = vector(dtype='int64')
givens = {shape: (10, 10)}
else:
shape = (10, 10)
givens = {}
u0 = rng.uniform(shape)
u1 = rng.uniform(shape)
f0 = theano.function([], u0, mode=mode_with_gpu, givens=givens)
f1 = theano.function([], u1, mode=mode_with_gpu, givens=givens)
v0list = [f0() for i in range(3)]
v1list = [f1() for i in range(3)]
......@@ -36,14 +42,24 @@ def test_uniform_basic():
assert .25 <= v.mean() <= .75
def test_normal_basic():
rng = CURAND_RandomStreams(234)
def test_uniform_basic():
yield check_uniform_basic, True
yield check_uniform_basic, False
u0 = rng.normal((10, 10))
u1 = rng.normal((10, 10))
f0 = theano.function([], u0, mode=mode_with_gpu)
f1 = theano.function([], u1, mode=mode_with_gpu)
def check_normal_basic(shape_as_theano_variable):
rng = CURAND_RandomStreams(234)
if shape_as_theano_variable:
shape = vector(dtype='int64')
givens = {shape: (10, 10)}
else:
shape = (10, 10)
givens = {}
u0 = rng.normal(shape)
u1 = rng.normal(shape)
f0 = theano.function([], u0, mode=mode_with_gpu, givens=givens)
f1 = theano.function([], u1, mode=mode_with_gpu, givens=givens)
v0list = [f0() for i in range(3)]
v1list = [f1() for i in range(3)]
......@@ -61,6 +77,11 @@ def test_normal_basic():
assert -.5 <= v.mean() <= .5
def test_normal_basic():
yield check_normal_basic, True
yield check_normal_basic, False
def compare_speed():
# To run this speed comparison
# cd <directory of this file>
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论