Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c7c2a019
提交
c7c2a019
authored
1月 13, 2014
作者:
Arnaud Bergeron
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add first version of MRG for gpuarray (with tests).
上级
1b4fd86f
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
334 行增加
和
0 行删除
+334
-0
rng_mrg.py
theano/sandbox/rng_mrg.py
+231
-0
test_rng_mrg.py
theano/sandbox/test_rng_mrg.py
+103
-0
没有找到文件。
theano/sandbox/rng_mrg.py
浏览文件 @
c7c2a019
...
...
@@ -608,6 +608,237 @@ class GPU_mrg_uniform(mrg_uniform_base, GpuOp):
return
(
7
,)
class
GPUA_mrg_uniform
(
mrg_uniform_base
):
#GpuArray version
@classmethod
def
new
(
cls
,
rstate
,
ndim
,
dtype
,
size
):
v_size
=
as_tensor_variable
(
size
)
if
ndim
is
None
:
ndim
=
get_vector_length
(
v_size
)
op
=
cls
(
GpuArrayType
(
dtype
,
(
False
,)
*
ndim
))
return
op
(
rstate
,
cast
(
v_size
,
'int32'
))
def
c_headers
(
self
):
return
[
"<compyte/ext_cuda.h>"
]
def
c_init_code
(
self
):
return
[
"setup_ext_cuda();"
]
def
c_support_code_apply
(
self
,
node
,
nodename
):
if
self
.
output_type
.
dtype
==
'float32'
:
otype
=
'float'
NORM
=
'4.6566126e-10f'
# numpy.float32(1.0/(2**31+65))
# this was determined by finding the biggest number such that
# numpy.float32(number * M1) < 1.0
else
:
otype
=
'double'
NORM
=
'4.656612873077392578125e-10'
return
"""
static int
%(nodename)
s_printed_warning = 0;
static __global__ void
%(nodename)
s_mrg_uniform(
%(otype)
s*sample_data,
npy_int32*state_data,
const int Nsamples,
const int Nstreams_used)
{
const npy_int32 i0 = 0;
const npy_int32 i7 = 7;
const npy_int32 i9 = 9;
const npy_int32 i15 = 15;
const npy_int32 i16 = 16;
const npy_int32 i22 = 22;
const npy_int32 i24 = 24;
const npy_int32 M1 = 2147483647; //2^31 - 1
const npy_int32 M2 = 2147462579; //2^31 - 21069
const npy_int32 MASK12 = 511; //2^9 - 1
const npy_int32 MASK13 = 16777215; //2^24 - 1
const npy_int32 MASK2 = 65535; //2^16 - 1
const npy_int32 MULT2 = 21069;
const unsigned int numThreads = blockDim.x * gridDim.x;
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
npy_int32 y1, y2, x11, x12, x13, x21, x22, x23;
if (idx < Nstreams_used)
{
x11 = state_data[idx*6+0];
x12 = state_data[idx*6+1];
x13 = state_data[idx*6+2];
x21 = state_data[idx*6+3];
x22 = state_data[idx*6+4];
x23 = state_data[idx*6+5];
for (int i = idx; i < Nsamples; i += Nstreams_used)
{
y1 = ((x12 & MASK12) << i22) + (x12 >> i9) + ((x13 & MASK13) << i7) + (x13 >> i24);
y1 -= (y1 < 0 || y1 >= M1) ? M1 : 0;
y1 += x13;
y1 -= (y1 < 0 || y1 >= M1) ? M1 : 0;
x13 = x12;
x12 = x11;
x11 = y1;
y1 = ((x21 & MASK2) << i15) + (MULT2 * (x21 >> i16));
y1 -= (y1 < 0 || y1 >= M2) ? M2 : 0;
y2 = ((x23 & MASK2) << i15) + (MULT2 * (x23 >> i16));
y2 -= (y2 < 0 || y2 >= M2) ? M2 : 0;
y2 += x23;
y2 -= (y2 < 0 || y2 >= M2) ? M2 : 0;
y2 += y1;
y2 -= (y2 < 0 || y2 >= M2) ? M2 : 0;
x23 = x22;
x22 = x21;
x21 = y2;
if (x11 <= x21) {
sample_data[i] = (x11 - x21 + M1) *
%(NORM)
s;
}
else
{
sample_data[i] = (x11 - x21) *
%(NORM)
s;
}
}
state_data[idx*6+0]= x11;
state_data[idx*6+1]= x12;
state_data[idx*6+2]= x13;
state_data[idx*6+3]= x21;
state_data[idx*6+4]= x22;
state_data[idx*6+5]= x23;
}
}
"""
%
locals
()
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
rstate
,
size
=
inp
o_rstate
,
o_sample
=
out
inplace
=
int
(
self
.
inplace
)
ndim
=
self
.
output_type
.
ndim
o_type_num
=
numpy
.
asarray
(
0
,
dtype
=
self
.
output_type
.
dtype
)
.
dtype
.
num
fail
=
sub
[
'fail'
]
if
self
.
output_type
.
dtype
==
'float32'
:
otype
=
'float'
else
:
otype
=
'double'
return
"""
//////// <code generated by mrg_uniform>
size_t odims[
%(ndim)
s];
unsigned int n_elements = 1;
unsigned int n_streams, n_streams_used_in_this_call;
int must_alloc_sample = ((NULL ==
%(o_sample)
s)
|| !pygpu_GpuArray_Check(py_
%(o_sample)
s)
|| !(
%(o_sample)
s->ga.flags & GA_C_CONTIGUOUS)
|| (PyGpuArray_NDIM(
%(o_sample)
s) !=
%(ndim)
s));
if (PyArray_NDIM(
%(size)
s) != 1)
{
PyErr_SetString(PyExc_ValueError, "size must be vector");
%(fail)
s
}
if (PyArray_DIMS(
%(size)
s)[0] !=
%(ndim)
s)
{
PyErr_Format(PyExc_ValueError, "size must have length
%%
i (not
%%
i)",
%(ndim)
s, PyArray_DIMS(
%(size)
s)[0]);
%(fail)
s
}
if (PyArray_DESCR(
%(size)
s)->type_num != NPY_INT32)
{
PyErr_SetString(PyExc_ValueError, "size must be int32");
%(fail)
s
}
for (int i = 0; i <
%(ndim)
s; ++i)
{
odims[i] = ((npy_int32*)(PyArray_BYTES(
%(size)
s) + PyArray_STRIDES(
%(size)
s)[0] * i))[0];
n_elements *= odims[i];
must_alloc_sample = (must_alloc_sample
|| PyGpuArray_DIMS(
%(o_sample)
s)[i] != odims[i]);
}
if (must_alloc_sample)
{
Py_XDECREF(
%(o_sample)
s);
%(o_sample)
s = pygpu_empty(
%(ndim)
s, odims, GA_FLOAT, GA_C_ORDER,
pygpu_default_context(), Py_None);
if(!
%(o_sample)
s)
{
%(fail)
s;
}
}
if (!pygpu_GpuArray_Check(py_
%(rstate)
s))
{
PyErr_Format(PyExc_ValueError, "rstate must be gpuarray");
%(fail)
s;
}
Py_XDECREF(
%(o_rstate)
s);
if (
%(inplace)
s)
{
Py_INCREF(
%(rstate)
s);
%(o_rstate)
s =
%(rstate)
s;
}
else
{
%(o_rstate)
s = pygpu_copy(
%(rstate)
s);
}
if (PyGpuArray_NDIM(
%(o_rstate)
s) != 1)
{
PyErr_SetString(PyExc_ValueError, "rstate must be vector");
%(fail)
s;
}
if (PyGpuArray_DIMS(
%(o_rstate)
s)[0]
%% 6
)
{
PyErr_Format(PyExc_ValueError, "rstate len must be multiple of 6");
%(fail)
s;
}
n_streams = PyGpuArray_DIMS(
%(o_rstate)
s)[0]/6;
n_streams_used_in_this_call = std::min(n_streams, n_elements);
{
unsigned int threads_per_block = std::min((unsigned int)n_streams_used_in_this_call, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
unsigned int n_blocks = std::min(ceil_intdiv((unsigned int)n_streams_used_in_this_call, threads_per_block), (unsigned int)NUM_VECTOR_OP_BLOCKS);
if (threads_per_block * n_blocks < n_streams)
{
if (!
%(nodename)
s_printed_warning)
fprintf(stderr, "WARNING: unused streams above
%%
i (Tune GPU_mrg get_n_streams)
\\
n", threads_per_block * n_blocks );
%(nodename)
s_printed_warning = 1;
}
cuda_enter(pygpu_default_context()->ctx);
%(nodename)
s_mrg_uniform<<<n_blocks,threads_per_block>>>(
cuda_get_ptr(
%(o_sample)
s),
cuda_get_ptr(
%(o_rstate)
s),
n_elements, n_streams_used_in_this_call);
/* We need the full sync since we just modified libgpu
objects without informing it */
cudaDeviceSynchronize();
}
cudaError_t err = cudaGetLastError();
cuda_exit(pygpu_default_context()->ctx);
if (cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s.
\\
n", "mrg_uniform", cudaGetErrorString(err));
%(fail)
s;
}
//////// </ code generated by mrg_uniform>
"""
%
locals
()
def
c_code_cache_version
(
self
):
return
(
0
,)
def
guess_n_streams
(
size
,
warn
=
True
):
"""
Return a guess at a good number of streams.
...
...
theano/sandbox/test_rng_mrg.py
浏览文件 @
c7c2a019
...
...
@@ -302,6 +302,109 @@ def test_consistency_GPU_parallel():
assert
(
numpy
.
allclose
(
samples
,
java_samples
))
def
test_consistency_GPUA_serial
():
'''Verify that the random numbers generated by GPUA_mrg_uniform, serially,
are the same as the reference (Java) implementation by L'Ecuyer et al.
'''
from
theano.sandbox.gpuarray.tests.test_basic_ops
import
\
mode_with_gpu
as
mode
from
theano.sandbox.gpuarray.type
import
gpuarray_shared_constructor
seed
=
12345
n_samples
=
5
n_streams
=
12
n_substreams
=
7
samples
=
[]
curr_rstate
=
numpy
.
array
([
seed
]
*
6
,
dtype
=
'int32'
)
for
i
in
range
(
n_streams
):
stream_rstate
=
curr_rstate
.
copy
()
for
j
in
range
(
n_substreams
):
substream_rstate
=
numpy
.
array
(
stream_rstate
.
copy
(),
dtype
=
'int32'
)
# Transfer to device
rstate
=
gpuarray_shared_constructor
(
substream_rstate
)
new_rstate
,
sample
=
rng_mrg
.
GPUA_mrg_uniform
.
new
(
rstate
,
ndim
=
None
,
dtype
=
'float32'
,
size
=
(
1
,))
rstate
.
default_update
=
new_rstate
# Not really necessary, just mimicking
# rng_mrg.MRG_RandomStreams' behavior
sample
.
rstate
=
rstate
sample
.
update
=
(
rstate
,
new_rstate
)
# We need the sample back in the main memory
cpu_sample
=
tensor
.
as_tensor_variable
(
sample
)
f
=
theano
.
function
([],
cpu_sample
,
mode
=
mode
)
for
k
in
range
(
n_samples
):
s
=
f
()
samples
.
append
(
s
)
# next substream
stream_rstate
=
rng_mrg
.
ff_2p72
(
stream_rstate
)
# next stream
curr_rstate
=
rng_mrg
.
ff_2p134
(
curr_rstate
)
samples
=
numpy
.
array
(
samples
)
.
flatten
()
assert
(
numpy
.
allclose
(
samples
,
java_samples
))
def
test_consistency_GPUA_parallel
():
'''Verify that the random numbers generated by GPUA_mrg_uniform, in
parallel, are the same as the reference (Java) implementation by
L'Ecuyer et al.
'''
from
theano.sandbox.gpuarray.tests.test_basic_ops
import
\
mode_with_gpu
as
mode
from
theano.sandbox.gpuarray.type
import
gpuarray_shared_constructor
seed
=
12345
n_samples
=
5
n_streams
=
12
n_substreams
=
7
# 7 samples will be drawn in parallel
samples
=
[]
curr_rstate
=
numpy
.
array
([
seed
]
*
6
,
dtype
=
'int32'
)
for
i
in
range
(
n_streams
):
stream_samples
=
[]
rstate
=
[
curr_rstate
.
copy
()]
for
j
in
range
(
1
,
n_substreams
):
rstate
.
append
(
rng_mrg
.
ff_2p72
(
rstate
[
-
1
]))
rstate
=
numpy
.
asarray
(
rstate
)
.
flatten
()
rstate
=
gpuarray_shared_constructor
(
rstate
)
new_rstate
,
sample
=
rng_mrg
.
GPUA_mrg_uniform
.
new
(
rstate
,
ndim
=
None
,
dtype
=
'float32'
,
size
=
(
n_substreams
,))
rstate
.
default_update
=
new_rstate
# Not really necessary, just mimicking
# rng_mrg.MRG_RandomStreams' behavior
sample
.
rstate
=
rstate
sample
.
update
=
(
rstate
,
new_rstate
)
# We need the sample back in the main memory
cpu_sample
=
tensor
.
as_tensor_variable
(
sample
)
f
=
theano
.
function
([],
cpu_sample
,
mode
=
mode
)
for
k
in
range
(
n_samples
):
s
=
f
()
stream_samples
.
append
(
s
)
samples
.
append
(
numpy
.
array
(
stream_samples
)
.
T
.
flatten
())
# next stream
curr_rstate
=
rng_mrg
.
ff_2p134
(
curr_rstate
)
samples
=
numpy
.
array
(
samples
)
.
flatten
()
assert
(
numpy
.
allclose
(
samples
,
java_samples
))
def
basictest
(
f
,
steps
,
sample_size
,
prefix
=
""
,
allow_01
=
False
,
inputs
=
None
,
target_avg
=
0.5
,
target_std
=
None
,
mean_rtol
=
0.01
,
std_tol
=
0.01
):
if
inputs
is
None
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论