Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f8414635
提交
f8414635
authored
3月 18, 2010
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rng_mgr - GPU version
上级
ddf3109c
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
307 行增加
和
22 行删除
+307
-22
rng_mrg.py
theano/sandbox/rng_mrg.py
+307
-22
没有找到文件。
theano/sandbox/rng_mrg.py
浏览文件 @
f8414635
...
@@ -9,7 +9,12 @@ import sys
...
@@ -9,7 +9,12 @@ import sys
import
numpy
import
numpy
from
theano
import
Op
,
Apply
,
shared
,
config
from
theano
import
Op
,
Apply
,
shared
,
config
from
theano.tensor
import
raw_random
,
TensorType
,
as_tensor_variable
,
get_vector_length
,
cast
from
theano.tensor
import
raw_random
,
TensorType
,
as_tensor_variable
,
get_vector_length
,
cast
,
opt
from
theano.compile
import
optdb
from
theano.gof
import
local_optimizer
from
theano.sandbox.cuda.opt
import
register_opt
as
gpu_register_opt
from
theano.sandbox.cuda
import
cuda_enabled
,
CudaNdarrayType
#, gpu_from_host, host_from_gpu, CudaNdarrayType
def
mulmod
(
a
,
b
,
c
,
m
):
def
mulmod
(
a
,
b
,
c
,
m
):
r
=
numpy
.
int32
(
numpy
.
int64
(
a
*
b
+
c
)
%
m
)
r
=
numpy
.
int32
(
numpy
.
int64
(
a
*
b
+
c
)
%
m
)
...
@@ -114,8 +119,9 @@ def mrg_next_value(rstate, new_rstate):
...
@@ -114,8 +119,9 @@ def mrg_next_value(rstate, new_rstate):
else
:
else
:
return
(
x11
-
x21
)
*
NORM
return
(
x11
-
x21
)
*
NORM
class
mrg_uniform
(
Op
):
class
mrg_uniform
_base
(
Op
):
def
__init__
(
self
,
output_type
,
inplace
=
False
):
def
__init__
(
self
,
output_type
,
inplace
=
False
):
Op
.
__init__
(
self
)
self
.
output_type
=
output_type
self
.
output_type
=
output_type
self
.
inplace
=
inplace
self
.
inplace
=
inplace
if
inplace
:
if
inplace
:
...
@@ -129,6 +135,18 @@ class mrg_uniform(Op):
...
@@ -129,6 +135,18 @@ class mrg_uniform(Op):
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
^
hash
(
self
.
output_type
)
^
hash
(
self
.
inplace
)
return
hash
(
type
(
self
))
^
hash
(
self
.
output_type
)
^
hash
(
self
.
inplace
)
def
make_node
(
self
,
rstate
,
size
):
# error checking slightly redundant here, since
# this op should not be called directly.
#
# call through MRG_RandomStreams instead.
return
Apply
(
self
,
[
rstate
,
size
],
[
rstate
.
type
(),
self
.
output_type
()])
class
mrg_uniform
(
mrg_uniform_base
):
#CPU VERSION
@classmethod
@classmethod
def
new
(
cls
,
rstate
,
ndim
,
dtype
,
size
):
def
new
(
cls
,
rstate
,
ndim
,
dtype
,
size
):
v_size
=
as_tensor_variable
(
size
)
v_size
=
as_tensor_variable
(
size
)
...
@@ -137,12 +155,10 @@ class mrg_uniform(Op):
...
@@ -137,12 +155,10 @@ class mrg_uniform(Op):
op
=
cls
(
TensorType
(
dtype
,
(
False
,)
*
ndim
))
op
=
cls
(
TensorType
(
dtype
,
(
False
,)
*
ndim
))
return
op
(
rstate
,
cast
(
v_size
,
'int32'
))
return
op
(
rstate
,
cast
(
v_size
,
'int32'
))
def
make_node
(
self
,
rstate
,
size
):
return
Apply
(
self
,
[
rstate
,
size
],
[
rstate
.
type
(),
self
.
output_type
()])
def
perform
(
self
,
node
,
(
rstate
,
size
),
(
o_rstate
,
o_sample
)):
def
perform
(
self
,
node
,
(
rstate
,
size
),
(
o_rstate
,
o_sample
)):
n_elements
=
1
n_elements
=
1
rstate
=
numpy
.
asarray
(
rstate
)
# bring state from GPU if necessary
if
not
self
.
inplace
:
if
not
self
.
inplace
:
rstate
=
rstate
.
copy
()
rstate
=
rstate
.
copy
()
...
@@ -157,8 +173,8 @@ class mrg_uniform(Op):
...
@@ -157,8 +173,8 @@ class mrg_uniform(Op):
sample
=
mrg_next_value
(
rstate
[
i
%
n_streams
],
rstate
[
i
%
n_streams
])
sample
=
mrg_next_value
(
rstate
[
i
%
n_streams
],
rstate
[
i
%
n_streams
])
rval
[
i
]
=
sample
rval
[
i
]
=
sample
o_rstate
[
0
]
=
rstate
.
copy
()
o_rstate
[
0
]
=
node
.
outputs
[
0
]
.
type
.
filter
(
rstate
)
# send to GPU if necessary
o_sample
[
0
]
=
rval
.
reshape
(
size
)
o_sample
[
0
]
=
node
.
outputs
[
1
]
.
type
.
filter
(
rval
.
reshape
(
size
))
# send to GPU if necessary
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
()
return
()
...
@@ -317,10 +333,223 @@ class mrg_uniform(Op):
...
@@ -317,10 +333,223 @@ class mrg_uniform(Op):
//////// </ code generated by mrg_uniform>
//////// </ code generated by mrg_uniform>
"""
%
locals
()
"""
%
locals
()
class GPU_mrg_uniform(mrg_uniform_base):
    """GPU implementation of the MRG31k3p uniform sampler.

    Same Op interface as the CPU ``mrg_uniform``: inputs are (rstate, size),
    outputs are (new_rstate, sample).  Here the state and the sample are
    CudaNdarrays, and sampling is performed by a generated CUDA kernel with
    one MRG stream per GPU thread (state is 6 int32 words per stream).
    """
    #GPU VERSION

    @classmethod
    def new(cls, rstate, ndim, dtype, size):
        # Alternate constructor: build the Op for the requested rank and
        # apply it to (rstate, size).
        # NOTE(review): ``dtype`` is not used in this GPU version
        # (CudaNdarrayType presumably fixes the element type); the parameter
        # is kept for interface parity with mrg_uniform.new -- confirm.
        v_size = as_tensor_variable(size)
        if ndim is None:
            ndim = get_vector_length(v_size)
        op = cls(CudaNdarrayType((False,)*ndim))
        return op(rstate, cast(v_size, 'int32'))

    def c_support_code_apply(self, node, nodename):
        # Emit the CUDA kernel that c_code() launches.  NORM is the literal
        # constant that maps a 31-bit integer sample into [0, 1).
        if self.output_type.dtype == 'float32':
            otype = 'float'
            NORM = '4.6566126e-10f' #numpy.float32(1.0/(2**31+65))
            # this was determined by finding the biggest number such that
            # numpy.float32(number * M1) < 1.0
        else:
            otype = 'double'
            NORM = '4.656612873077392578125e-10'
        # The kernel loads one 6-word stream state per thread, generates
        # samples for indices i = idx, idx+numThreads, ..., then writes the
        # advanced state back.
        # NOTE(review): ``or`` inside the kernel source relies on the C++
        # alternative operator token (equivalent to ``||``) -- confirm the
        # CUDA compiler accepts it; the rest of the kernel uses ``||``.
        return """
        static __global__ void %(nodename)s_mrg_uniform(
                %(otype)s*sample_data,
                npy_int32*state_data,
                const int Nsamples)
        {
            const npy_int32 i0 = 0;
            const npy_int32 i7 = 7;
            const npy_int32 i9 = 9;
            const npy_int32 i15 = 15;
            const npy_int32 i16 = 16;
            const npy_int32 i22 = 22;
            const npy_int32 i24 = 24;

            const npy_int32 M1 = 2147483647;      //2^31 - 1
            const npy_int32 M2 = 2147462579;      //2^31 - 21069
            const npy_int32 MASK12 = 511;         //2^9 - 1
            const npy_int32 MASK13 = 16777215;    //2^24 - 1
            const npy_int32 MASK2 = 65535;        //2^16 - 1
            const npy_int32 MULT2 = 21069;

            const unsigned int numThreads = blockDim.x * gridDim.x;
            const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
            npy_int32 y1, y2, x11, x12, x13, x21, x22, x23;

            x11 = state_data[idx*6+0];
            x12 = state_data[idx*6+1];
            x13 = state_data[idx*6+2];
            x21 = state_data[idx*6+3];
            x22 = state_data[idx*6+4];
            x23 = state_data[idx*6+5];

            for (int i = idx; i < Nsamples; i += numThreads)
            {
                y1 = ((x12 & MASK12) << i22) + (x12 >> i9) + ((x13 & MASK13) << i7) + (x13 >> i24);
                if ((y1 < 0 || y1 >= M1))     //must also check overflow
                    y1 -= M1;
                y1 += x13;
                if ((y1 < 0 or y1 >= M1))
                    y1 -= M1;
                x13 = x12;
                x12 = x11;
                x11 = y1;

                y1 = ((x21 & MASK2) << i15) + (MULT2 * (x21 >> i16));
                if (y1 < 0 || y1 >= M2)
                    y1 -= M2;
                y2 = ((x23 & MASK2) << i15) + (MULT2 * (x23 >> i16));
                if (y2 < 0 || y2 >= M2)
                    y2 -= M2;
                y2 += x23;
                if (y2 < 0 || y2 >= M2)
                    y2 -= M2;
                y2 += y1;
                if (y2 < 0 or y2 >= M2)
                    y2 -= M2;
                x23 = x22;
                x22 = x21;
                x21 = y2;

                if (x11 <= x21) {
                    sample_data[i] = (x11 - x21 + M1) * %(NORM)s;
                }
                else
                {
                    sample_data[i] = (x11 - x21) * %(NORM)s;
                }
            }

            state_data[idx*6+0]= x11;
            state_data[idx*6+1]= x12;
            state_data[idx*6+2]= x13;
            state_data[idx*6+3]= x21;
            state_data[idx*6+4]= x22;
            state_data[idx*6+5]= x23;
        }
        """ % locals()

    def c_code_cache_version(self):
        # Empty tuple: the compiled code for this Op is never cached.
        return ()

    def c_code(self, node, nodename, (rstate, size), (o_rstate, o_sample), sub):
        # Host-side C++: validate ``size`` (an int32 vector of length ndim),
        # (re)allocate the output CudaNdarray when needed, reuse or copy the
        # state depending on self.inplace, then launch the kernel emitted by
        # c_support_code_apply above.
        inplace = int(self.inplace)
        ndim = self.output_type.ndim
        # NOTE(review): o_type_num does not appear in the template below --
        # it looks unused; confirm before removing.
        o_type_num = numpy.asarray(0, dtype=self.output_type.dtype).dtype.num
        fail = sub['fail']
        if self.output_type.dtype == 'float32':
            otype = 'float'
        else:
            otype = 'double'
        SYNC = "CNDA_THREAD_SYNC";
        return """
        //////// <code generated by mrg_uniform>
        int odims[%(ndim)s];
        int n_elements = 1;
        unsigned int n_streams;
        int must_alloc_sample = ((NULL == %(o_sample)s)
                || !CudaNdarray_Check(py_%(o_sample)s)
                || (%(o_sample)s->nd != %(ndim)s));

        if (%(size)s->nd != 1)
        {
            PyErr_SetString(PyExc_ValueError, "size must be vector");
            %(fail)s
        }
        if (%(size)s->dimensions[0] != %(ndim)s)
        {
            PyErr_Format(PyExc_ValueError, "size must have length %%i", %(ndim)s);
            %(fail)s
        }
        if (%(size)s->descr->type_num != PyArray_INT32)
        {
            PyErr_SetString(PyExc_ValueError, "size must be int32");
            %(fail)s
        }
        for (int i = 0; i < %(ndim)s; ++i)
        {
            odims[i] = ((npy_int32*)(%(size)s->data + %(size)s->strides[0] * i))[0];
            n_elements *= odims[i];
            must_alloc_sample = (must_alloc_sample
                    || CudaNdarray_HOST_DIMS(%(o_sample)s)[i] != odims[i]);
        }
        if (must_alloc_sample)
        {
            Py_XDECREF(%(o_sample)s);
            %(o_sample)s = (CudaNdarray*)CudaNdarray_NewDims(%(ndim)s, odims);
            if(!%(o_sample)s)
            {
                %(fail)s;
            }
        }
        if (!CudaNdarray_Check(py_%(rstate)s))
        {
            PyErr_Format(PyExc_ValueError, "rstate must be cudandarray");
            %(fail)s;
        }

        Py_XDECREF(%(o_rstate)s);
        if (%(inplace)s)
        {
            Py_INCREF(%(rstate)s);
            %(o_rstate)s = %(rstate)s;
        }
        else
        {
            %(o_rstate)s = (CudaNdarray*)CudaNdarray_Copy(%(rstate)s);
        }

        if (%(o_rstate)s->nd != 1)
        {
            PyErr_SetString(PyExc_ValueError, "rstate must be vector");
            %(fail)s;
        }
        if (CudaNdarray_HOST_DIMS(%(o_rstate)s)[0] %% 6)
        {
            PyErr_Format(PyExc_ValueError, "rstate len must be multiple of 6");
            %(fail)s;
        }
        n_streams = CudaNdarray_HOST_DIMS(%(o_rstate)s)[0]/6;

        {
            unsigned int threads_per_block = std::min(n_streams, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
            unsigned int n_blocks = std::min(ceil_intdiv(n_streams, threads_per_block), (unsigned int)NUM_VECTOR_OP_BLOCKS);
            if (threads_per_block * n_blocks < n_streams)
            {
                fprintf(stderr, "WARNING: unused streams above %%i (Tune GPU_mrg get_n_streams)\\n", threads_per_block * n_blocks );
            }
            %(nodename)s_mrg_uniform<<<n_blocks,threads_per_block>>>(
                    CudaNdarray_DEV_DATA(%(o_sample)s),
                    (npy_int32*)CudaNdarray_DEV_DATA(%(o_rstate)s),
                    n_elements);
        }
        %(SYNC)s;

        {
            cudaError_t err = cudaGetLastError();
            if( cudaSuccess != err)
            {
                PyErr_Format(PyExc_RuntimeError, "Cuda error: %%s: %%s.\\n", "mrg_uniform", cudaGetErrorString(err));
                %(fail)s;
            }
        }
        //////// </ code generated by mrg_uniform>
        """ % locals()
class
MRG_RandomStreams
(
object
):
class
MRG_RandomStreams
(
object
):
"""Module component with similar interface to numpy.random (numpy.random.RandomState)"""
"""Module component with similar interface to numpy.random (numpy.random.RandomState)"""
def
__init__
(
self
,
seed
=
None
):
def
__init__
(
self
,
seed
=
None
,
use_cuda
=
None
):
"""
"""
:type seed: None or int
:type seed: None or int
...
@@ -329,6 +558,10 @@ class MRG_RandomStreams(object):
...
@@ -329,6 +558,10 @@ class MRG_RandomStreams(object):
"""
"""
super
(
MRG_RandomStreams
,
self
)
.
__init__
()
super
(
MRG_RandomStreams
,
self
)
.
__init__
()
self
.
rstate
=
numpy
.
asarray
([
12345
]
*
6
,
dtype
=
'int32'
)
self
.
rstate
=
numpy
.
asarray
([
12345
]
*
6
,
dtype
=
'int32'
)
if
use_cuda
is
None
:
self
.
use_cuda
=
cuda_enabled
else
:
self
.
use_cuda
=
use_cuda
def
inc_rstate
(
self
):
def
inc_rstate
(
self
):
"""Update self.rstate to be skipped 2^134 steps forward to the next stream start"""
"""Update self.rstate to be skipped 2^134 steps forward to the next stream start"""
...
@@ -361,7 +594,6 @@ class MRG_RandomStreams(object):
...
@@ -361,7 +594,6 @@ class MRG_RandomStreams(object):
node_rstate
.
default_update
=
new_rstate
node_rstate
.
default_update
=
new_rstate
return
sample
return
sample
def
uniform
(
self
,
size
=
None
,
low
=
0.0
,
high
=
1.0
,
ndim
=
None
,
dtype
=
config
.
floatX
):
def
uniform
(
self
,
size
=
None
,
low
=
0.0
,
high
=
1.0
,
ndim
=
None
,
dtype
=
config
.
floatX
):
"""
"""
Sample a tensor of given size whose element from a uniform
Sample a tensor of given size whose element from a uniform
...
@@ -371,6 +603,24 @@ class MRG_RandomStreams(object):
...
@@ -371,6 +603,24 @@ class MRG_RandomStreams(object):
ndim may be a plain integer to supplement the missing
ndim may be a plain integer to supplement the missing
information.
information.
"""
"""
if
self
.
use_cuda
and
dtype
==
'float32'
:
rstates
=
self
.
get_substream_rstates
(
self
.
n_streams
(
size
))
rstates
=
rstates
.
flatten
()
# HACK - we use fact that int32 and float32 have same size to
# sneak ints into the CudaNdarray type.
# these *SHOULD NEVER BE USED AS FLOATS*
tmp_float_buf
=
numpy
.
frombuffer
(
rstates
.
data
,
dtype
=
'float32'
)
assert
tmp_float_buf
.
shape
==
rstates
.
shape
assert
tmp_float_buf
.
data
[:
24
]
==
rstates
.
data
[:
24
]
node_rstate
=
shared
(
tmp_float_buf
)
# transfer to device
assert
isinstance
(
node_rstate
.
type
,
CudaNdarrayType
)
# we can't use the normal mrg_uniform constructor + later optimization
# because of the tmp_float_buf hack above. There is
# currently no Theano node that will do a frombuffer reinterpretation.
u
=
self
.
pretty_return
(
node_rstate
,
*
GPU_mrg_uniform
.
new
(
node_rstate
,
ndim
,
dtype
,
size
))
else
:
node_rstate
=
shared
(
self
.
get_substream_rstates
(
self
.
n_streams
(
size
)))
node_rstate
=
shared
(
self
.
get_substream_rstates
(
self
.
n_streams
(
size
)))
u
=
self
.
pretty_return
(
node_rstate
,
u
=
self
.
pretty_return
(
node_rstate
,
*
mrg_uniform
.
new
(
node_rstate
,
ndim
,
dtype
,
size
))
*
mrg_uniform
.
new
(
node_rstate
,
ndim
,
dtype
,
size
))
...
@@ -380,6 +630,17 @@ class MRG_RandomStreams(object):
...
@@ -380,6 +630,17 @@ class MRG_RandomStreams(object):
raise
NotImplementedError
(
'Increase the size to match the broadcasting pattern of `low` and `high` arguments'
)
raise
NotImplementedError
(
'Increase the size to match the broadcasting pattern of `low` and `high` arguments'
)
return
r
return
r
@local_optimizer([None])
def mrg_random_make_inplace(node):
    """Graph optimization: replace a non-inplace MRG sampling Op with its
    inplace variant, so the random state buffer is updated in place.

    Returns the outputs of the replacement node, or ``False`` when the
    optimization does not apply to ``node``.
    """
    op = node.op
    # Test against the shared base class so the GPU version is matched too:
    # GPU_mrg_uniform derives from mrg_uniform_base, NOT from the CPU
    # mrg_uniform, so an isinstance(op, mrg_uniform) check would silently
    # skip GPU nodes despite the intent ("op might be gpu version").
    if isinstance(op, mrg_uniform_base) and not op.inplace:
        # op.__class__ keeps the CPU/GPU flavour of the original Op.
        new_op = op.__class__(op.output_type, inplace=True)
        return new_op.make_node(*node.inputs).outputs
    return False

# Register late (position 99) under the 'fast_run'/'inplace' tags;
# ignore_newtrees prevents re-visiting the nodes this optimizer creates.
optdb.register('random_make_inplace_mrg',
               opt.in2out(mrg_random_make_inplace, ignore_newtrees=True),
               99, 'fast_run', 'inplace')
#
#
#
#
#
#
...
@@ -391,37 +652,61 @@ import theano
...
@@ -391,37 +652,61 @@ import theano
def
test_rng0
():
def
test_rng0
():
def
basictest
(
f
,
steps
,
prefix
=
""
):
def
basictest
(
f
,
steps
,
prefix
=
""
):
dt
=
0.0
for
i
in
xrange
(
steps
):
t0
=
time
.
time
()
t0
=
time
.
time
()
l
=
[
f
()
for
i
in
xrange
(
steps
)]
ival
=
f
()
tt
=
time
.
time
()
dt
+=
time
.
time
()
-
t0
ival
=
numpy
.
asarray
(
ival
)
if
i
==
0
:
mean
=
numpy
.
array
(
ival
,
copy
=
True
)
else
:
alpha
=
1.0
/
(
1
+
i
)
mean
=
alpha
*
ival
+
(
1
-
alpha
)
*
mean
print
prefix
,
'mean'
,
numpy
.
mean
(
mean
)
assert
abs
(
numpy
.
mean
(
mean
)
-
0.5
)
<
.
01
,
'bad mean?'
print
prefix
,
'time'
,
dt
print
prefix
,
'elements'
,
steps
*
sample_size
[
0
]
*
sample_size
[
1
]
print
prefix
,
'samples/sec'
,
steps
*
sample_size
[
0
]
*
sample_size
[
1
]
/
dt
if
0
:
mean
,
std
,
min
,
max
=
numpy
.
mean
(
l
),
numpy
.
std
(
l
),
numpy
.
min
(
l
),
numpy
.
max
(
l
)
mean
,
std
,
min
,
max
=
numpy
.
mean
(
l
),
numpy
.
std
(
l
),
numpy
.
min
(
l
),
numpy
.
max
(
l
)
print
prefix
,
'mean'
,
mean
print
prefix
,
'mean'
,
mean
print
prefix
,
'std'
,
std
print
prefix
,
'std'
,
std
print
prefix
,
'min'
,
repr
(
min
)
print
prefix
,
'min'
,
repr
(
min
)
print
prefix
,
'max'
,
repr
(
max
)
print
prefix
,
'max'
,
repr
(
max
)
print
prefix
,
'samples/sec'
,
steps
*
sample_size
[
0
]
*
sample_size
[
1
]
/
(
tt
-
t0
)
assert
max
<
1.0
assert
max
<
1.0
assert
min
>=
0.0
assert
min
>=
0.0
assert
abs
(
mean
-
0.5
)
<
.
01
,
'bad mean?'
assert
abs
(
mean
-
0.5
)
<
.
01
,
'bad mean?'
sample_size
=
(
1000
,
100
)
R
=
MRG_RandomStreams
(
234
)
print
''
print
'ON CPU:'
sample_size
=
(
200
,
20
)
R
=
MRG_RandomStreams
(
234
,
use_cuda
=
False
)
u
=
R
.
uniform
(
size
=
sample_size
)
u
=
R
.
uniform
(
size
=
sample_size
)
print
"U dtype"
,
u
.
dtype
f
=
theano
.
function
([],
u
)
f
=
theano
.
function
([],
u
)
theano
.
printing
.
debugprint
(
f
)
print
'random?[:10]
\n
'
,
f
()[
0
,
0
:
10
]
basictest
(
f
,
1000
,
prefix
=
'mrg '
)
print
'random?'
,
f
()[
0
]
print
''
print
'random?'
,
f
()[
0
]
print
'ON GPU:'
R
=
MRG_RandomStreams
(
234
,
use_cuda
=
True
)
u
=
R
.
uniform
(
size
=
sample_size
)
assert
u
.
dtype
==
'float32'
#well, it's really that this test w GPU doesn't make sense otw
f
=
theano
.
function
([],
theano
.
Out
(
theano
.
sandbox
.
cuda
.
basic_ops
.
gpu_from_host
(
u
),
borrow
=
True
))
theano
.
printing
.
debugprint
(
f
)
print
'random?[:10]
\n
'
,
numpy
.
asarray
(
f
())[
0
,
0
:
10
]
basictest
(
f
,
1000
,
prefix
=
'mrg '
)
basictest
(
f
,
1000
,
prefix
=
'mrg '
)
print
''
print
'ON CPU w NUMPY:'
RR
=
theano
.
tensor
.
shared_randomstreams
.
RandomStreams
(
234
)
RR
=
theano
.
tensor
.
shared_randomstreams
.
RandomStreams
(
234
)
uu
=
RR
.
uniform
(
size
=
sample_size
)
uu
=
RR
.
uniform
(
size
=
sample_size
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论