提交 229d9db9 authored 作者: David Warde-Farley's avatar David Warde-Farley

TST: curand_rng for symbolic single dimension

上级 c39a88df
import numpy import numpy
import theano import theano
from theano.tensor import vector from theano.tensor import vector, constant
from theano.sandbox.cuda.rng_curand import CURAND_RandomStreams from theano.sandbox.cuda.rng_curand import CURAND_RandomStreams
from theano.sandbox.rng_mrg import MRG_RandomStreams from theano.sandbox.rng_mrg import MRG_RandomStreams
...@@ -10,11 +10,15 @@ else: ...@@ -10,11 +10,15 @@ else:
mode_with_gpu = theano.compile.mode.get_default_mode().including('gpu') mode_with_gpu = theano.compile.mode.get_default_mode().including('gpu')
def check_uniform_basic(shape_as_theano_variable): def check_uniform_basic(shape_as_theano_variable,
dim_as_theano_variable=False):
rng = CURAND_RandomStreams(234) rng = CURAND_RandomStreams(234)
if shape_as_theano_variable: if shape_as_theano_variable:
shape = vector(dtype='int64') shape = vector(dtype='int64')
givens = {shape: (10, 10)} givens = {shape: (10, 10)}
else:
if dim_as_theano_variable:
shape = (10, constant(10))
else: else:
shape = (10, 10) shape = (10, 10)
givens = {} givens = {}
...@@ -43,15 +47,19 @@ def check_uniform_basic(shape_as_theano_variable): ...@@ -43,15 +47,19 @@ def check_uniform_basic(shape_as_theano_variable):
def test_uniform_basic(): def test_uniform_basic():
yield check_uniform_basic, True
yield check_uniform_basic, False yield check_uniform_basic, False
yield check_uniform_basic, False, True
yield check_uniform_basic, True
def check_normal_basic(shape_as_theano_variable,
def check_normal_basic(shape_as_theano_variable): dim_as_theano_variable=False):
rng = CURAND_RandomStreams(234) rng = CURAND_RandomStreams(234)
if shape_as_theano_variable: if shape_as_theano_variable:
shape = vector(dtype='int64') shape = vector(dtype='int64')
givens = {shape: (10, 10)} givens = {shape: (10, 10)}
else:
if dim_as_theano_variable:
shape = (10, constant(10))
else: else:
shape = (10, 10) shape = (10, 10)
givens = {} givens = {}
...@@ -78,8 +86,9 @@ def check_normal_basic(shape_as_theano_variable): ...@@ -78,8 +86,9 @@ def check_normal_basic(shape_as_theano_variable):
def test_normal_basic(): def test_normal_basic():
yield check_normal_basic, True
yield check_normal_basic, False yield check_normal_basic, False
yield check_normal_basic, False, True
yield check_normal_basic, True
def compare_speed(): def compare_speed():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论