提交 f9cc24da，作者：delallea

Merge pull request #52 from nouiz/fix_curand_op

Fix curand op
...@@ -3,14 +3,20 @@ import theano ...@@ -3,14 +3,20 @@ import theano
from theano.sandbox.cuda.rng_curand import CURAND_RandomStreams from theano.sandbox.cuda.rng_curand import CURAND_RandomStreams
from theano.sandbox.rng_mrg import MRG_RandomStreams from theano.sandbox.rng_mrg import MRG_RandomStreams
if theano.config.mode == 'FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
else:
mode_with_gpu = theano.compile.mode.get_default_mode().including('gpu')
def test_uniform_basic(): def test_uniform_basic():
rng = CURAND_RandomStreams(234) rng = CURAND_RandomStreams(234)
u0 = rng.uniform((10,10)) u0 = rng.uniform((10, 10))
u1 = rng.uniform((10,10)) u1 = rng.uniform((10, 10))
f0 = theano.function([], u0) f0 = theano.function([], u0, mode=mode_with_gpu)
f1 = theano.function([], u1) f1 = theano.function([], u1, mode=mode_with_gpu)
v0list = [f0() for i in range(3)] v0list = [f0() for i in range(3)]
v1list = [f1() for i in range(3)] v1list = [f1() for i in range(3)]
...@@ -23,7 +29,7 @@ def test_uniform_basic(): ...@@ -23,7 +29,7 @@ def test_uniform_basic():
assert numpy.all(v0list[0] != v1list[0]) assert numpy.all(v0list[0] != v1list[0])
for v in v0list: for v in v0list:
assert v.shape == (10,10) assert v.shape == (10, 10)
assert v.min() >= 0 assert v.min() >= 0
assert v.max() <= 1 assert v.max() <= 1
assert v.min() < v.max() assert v.min() < v.max()
...@@ -33,11 +39,11 @@ def test_uniform_basic(): ...@@ -33,11 +39,11 @@ def test_uniform_basic():
def test_normal_basic(): def test_normal_basic():
rng = CURAND_RandomStreams(234) rng = CURAND_RandomStreams(234)
u0 = rng.normal((10,10)) u0 = rng.normal((10, 10))
u1 = rng.normal((10,10)) u1 = rng.normal((10, 10))
f0 = theano.function([], u0) f0 = theano.function([], u0, mode=mode_with_gpu)
f1 = theano.function([], u1) f1 = theano.function([], u1, mode=mode_with_gpu)
v0list = [f0() for i in range(3)] v0list = [f0() for i in range(3)]
v1list = [f1() for i in range(3)] v1list = [f1() for i in range(3)]
...@@ -50,7 +56,7 @@ def test_normal_basic(): ...@@ -50,7 +56,7 @@ def test_normal_basic():
assert numpy.all(v0list[0] != v1list[0]) assert numpy.all(v0list[0] != v1list[0])
for v in v0list: for v in v0list:
assert v.shape == (10,10) assert v.shape == (10, 10)
assert v.min() < v.max() assert v.min() < v.max()
assert -.5 <= v.mean() <= .5 assert -.5 <= v.mean() <= .5
...@@ -64,17 +70,17 @@ def compare_speed(): ...@@ -64,17 +70,17 @@ def compare_speed():
mrg = MRG_RandomStreams() mrg = MRG_RandomStreams()
crn = CURAND_RandomStreams(234) crn = CURAND_RandomStreams(234)
N=1000*100 N = 1000 * 100
dest = theano.shared(numpy.zeros(N,dtype=theano.config.floatX)) dest = theano.shared(numpy.zeros(N, dtype=theano.config.floatX))
mrg_u = theano.function([], [], updates={dest:mrg.uniform((N,))}, mrg_u = theano.function([], [], updates={dest: mrg.uniform((N,))},
profile='mrg uniform') profile='mrg uniform')
crn_u = theano.function([], [], updates={dest:crn.uniform((N,))}, crn_u = theano.function([], [], updates={dest: crn.uniform((N,))},
profile='crn uniform') profile='crn uniform')
mrg_n = theano.function([], [], updates={dest:mrg.normal((N,))}, mrg_n = theano.function([], [], updates={dest: mrg.normal((N,))},
profile='mrg normal') profile='mrg normal')
crn_n = theano.function([], [], updates={dest:crn.normal((N,))}, crn_n = theano.function([], [], updates={dest: crn.normal((N,))},
profile='crn normal') profile='crn normal')
for f in mrg_u, crn_u, mrg_n, crn_n: for f in mrg_u, crn_u, mrg_n, crn_n:
...@@ -86,6 +92,5 @@ def compare_speed(): ...@@ -86,6 +92,5 @@ def compare_speed():
for i in range(100): for i in range(100):
for f in mrg_u, crn_u, mrg_n, crn_n: for f in mrg_u, crn_u, mrg_n, crn_n:
# don't time the first call, it has some startup cost # don't time the first call, it has some startup cost
f.fn.time_thunks = (i>0) f.fn.time_thunks = (i > 0)
f() f()
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论