提交 059f00c4 authored 作者: David Warde-Farley's avatar David Warde-Farley

TST: curand_rng: specify_shape for symbolic shapes

上级 229d9db9
import numpy import numpy
import theano import theano
from theano.tensor import vector, constant from theano.tensor import vector, constant, specify_shape
from theano.sandbox.cuda.rng_curand import CURAND_RandomStreams from theano.sandbox.cuda.rng_curand import CURAND_RandomStreams
from theano.sandbox.rng_mrg import MRG_RandomStreams from theano.sandbox.rng_mrg import MRG_RandomStreams
...@@ -14,7 +14,7 @@ def check_uniform_basic(shape_as_theano_variable, ...@@ -14,7 +14,7 @@ def check_uniform_basic(shape_as_theano_variable,
dim_as_theano_variable=False): dim_as_theano_variable=False):
rng = CURAND_RandomStreams(234) rng = CURAND_RandomStreams(234)
if shape_as_theano_variable: if shape_as_theano_variable:
shape = vector(dtype='int64') shape = specify_shape(vector(dtype='int64'), (2,))
givens = {shape: (10, 10)} givens = {shape: (10, 10)}
else: else:
if dim_as_theano_variable: if dim_as_theano_variable:
...@@ -55,7 +55,7 @@ def check_normal_basic(shape_as_theano_variable, ...@@ -55,7 +55,7 @@ def check_normal_basic(shape_as_theano_variable,
dim_as_theano_variable=False): dim_as_theano_variable=False):
rng = CURAND_RandomStreams(234) rng = CURAND_RandomStreams(234)
if shape_as_theano_variable: if shape_as_theano_variable:
shape = vector(dtype='int64') shape = specify_shape(vector(dtype='int64'), (2,))
givens = {shape: (10, 10)} givens = {shape: (10, 10)}
else: else:
if dim_as_theano_variable: if dim_as_theano_variable:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论