提交 38ab51a1 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the GPUA version of MRG to use the same rstate shape as the CPU one.

上级 455705e7
......@@ -952,17 +952,25 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
}
}
if (PyGpuArray_NDIM(%(o_rstate)s) != 1)
if (PyGpuArray_NDIM(%(o_rstate)s) != 2)
{
PyErr_SetString(PyExc_ValueError, "rstate must be vector");
%(fail)s;
PyErr_SetString(PyExc_ValueError, "rstate must be a matrix");
%(fail)s
}
if (PyGpuArray_DIMS(%(o_rstate)s)[0] %% 6)
if (PyGpuArray_DIMS(%(o_rstate)s)[1] != 6)
{
PyErr_Format(PyExc_ValueError, "rstate len must be multiple of 6");
%(fail)s;
PyErr_Format(PyExc_ValueError, "rstate must have 6 columns");
%(fail)s
}
if (%(o_rstate)s->ga.typecode != GA_INT) {
PyErr_Format(PyExc_ValueError, "rstate must be int32");
%(fail)s
}
if (!GpuArray_CHKFLAGS(&%(o_rstate)s->ga, GA_C_CONTIGUOUS)) {
PyErr_Format(PyExc_ValueError, "rstate must be C contiguous");
%(fail)s
}
n_streams = PyGpuArray_DIMS(%(o_rstate)s)[0]/6;
n_streams = PyGpuArray_DIMS(%(o_rstate)s)[0];
if (n_streams > n_elements)
n_streams = n_elements;
......
......@@ -325,7 +325,8 @@ def test_consistency_GPUA_serial():
for i in range(n_streams):
stream_rstate = curr_rstate.copy()
for j in range(n_substreams):
substream_rstate = numpy.array(stream_rstate.copy(), dtype='int32')
substream_rstate = numpy.array([stream_rstate.copy()],
dtype='int32')
# Transfer to device
rstate = gpuarray_shared_constructor(substream_rstate)
......@@ -380,7 +381,7 @@ def test_consistency_GPUA_parallel():
rstate = [curr_rstate.copy()]
for j in range(1, n_substreams):
rstate.append(rng_mrg.ff_2p72(rstate[-1]))
rstate = numpy.asarray(rstate).flatten()
rstate = numpy.asarray(rstate)
rstate = gpuarray_shared_constructor(rstate)
new_rstate, sample = rng_mrg.GPUA_mrg_uniform.new(rstate, ndim=None,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论