Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6c955ecf
提交
6c955ecf
authored
8月 23, 2011
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
New version of GpuAdvancedSubtensor1 with gpu code when input have up to 3…
New version of GpuAdvancedSubtensor1 with gpu code when input have up to 3 dimensions or is c_contiguous.
上级
6005ac2b
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
443 行增加
和
9 行删除
+443
-9
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+26
-0
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+366
-8
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+6
-0
test_basic_ops.py
theano/sandbox/cuda/tests/test_basic_ops.py
+45
-1
没有找到文件。
theano/sandbox/cuda/basic_ops.py
浏览文件 @
6c955ecf
...
@@ -1891,6 +1891,8 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
...
@@ -1891,6 +1891,8 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
"""
"""
Implement AdvancedSubtensor1 on the gpu.
Implement AdvancedSubtensor1 on the gpu.
"""
"""
assert_fast
=
None
def
make_node
(
self
,
x
,
ilist
):
def
make_node
(
self
,
x
,
ilist
):
x_
=
as_cuda_ndarray_variable
(
x
)
x_
=
as_cuda_ndarray_variable
(
x
)
ilist_
=
tensor
.
as_tensor_variable
(
ilist
)
ilist_
=
tensor
.
as_tensor_variable
(
ilist
)
...
@@ -1908,8 +1910,32 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
...
@@ -1908,8 +1910,32 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
#super(GpuAdvancedSubtensor1, self).perform(node, inp, out_)
#super(GpuAdvancedSubtensor1, self).perform(node, inp, out_)
x
,
idx
=
inp
x
,
idx
=
inp
out
,
=
out_
out
,
=
out_
new_method
=
True
#TODO: if more then 3 dims, reshape the inputs if it is contiguous.
x_orig
=
x
if
x
.
ndim
>
3
and
x
.
is_c_contiguous
():
x
=
x
.
reshape
((
x
.
shape
[
0
],
numpy
.
prod
(
x
.
shape
[
1
:])))
if
x
.
ndim
<=
3
:
if
self
.
assert_fast
is
not
None
:
assert
self
.
assert_fast
==
True
,
(
"GpuAdvancedSubtensor1 used the fast version"
)
# Support x with dimensions 1,2,3 only.
o
=
x
.
take
(
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
(
idx
.
astype
(
"float32"
)),
0
,
out_
[
0
][
0
])
# idx, axis, return[, clipmode]
if
x
is
not
x_orig
:
o
=
o
.
reshape
((
len
(
idx
),)
+
x_orig
.
shape
[
1
:])
out
[
0
]
=
o
else
:
if
self
.
assert_fast
is
not
None
:
assert
self
.
assert_fast
==
False
,
(
"GpuAdvancedSubtensor1 didn't used the fast version"
)
if
(
out_
[
0
][
0
]
is
None
or
out_
[
0
][
0
]
.
shape
!=
(
len
(
idx
),)
+
x
.
shape
[
1
:]):
o
=
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
.
zeros
((
len
(
idx
),)
+
o
=
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
.
zeros
((
len
(
idx
),)
+
x
.
shape
[
1
:])
x
.
shape
[
1
:])
else
:
o
=
out_
[
0
][
0
]
for
(
j
,
i
)
in
enumerate
(
idx
):
for
(
j
,
i
)
in
enumerate
(
idx
):
o
[
j
]
=
x
[
i
]
o
[
j
]
=
x
[
i
]
out
[
0
]
=
o
out
[
0
]
=
o
...
...
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
6c955ecf
...
@@ -682,6 +682,369 @@ PyObject * CudaNdarray_View(const CudaNdarray * self)
...
@@ -682,6 +682,369 @@ PyObject * CudaNdarray_View(const CudaNdarray * self)
return
(
PyObject
*
)
rval
;
return
(
PyObject
*
)
rval
;
}
}
enum
operator_t
{
IADD
=
0
,
IDIV
,
CPY
,
N_ELEMWISE_OPS
// This is to know the number of operation
};
/*
* d0,... are the output dims
* indices are a list of index to operate on
* They are int32 viewed as float32.
* a is the output
* b is the input
* dB0, the source leading dimensions size
*/
template
<
int
operator_num
>
__global__
void
k_take_3
(
const
int
d0
,
const
int
d1
,
const
int
d2
,
const
float
*
indices
,
float
*
a
,
const
int
sA0
,
const
int
sA1
,
const
int
sA2
,
const
float
*
b
,
const
int
dB0
,
const
int
sB0
,
const
int
sB1
,
const
int
sB2
,
int
*
err
){
for
(
int
i0
=
blockIdx
.
x
;
i0
<
d0
;
i0
+=
gridDim
.
x
){
int
idx
=
(
int
)
indices
[
i0
];
if
(
idx
<
0
)
idx
+=
dB0
;
// To allow negative indexing.
if
((
idx
<
0
)
||
(
idx
>=
dB0
))
*
err
=
0xFFFF
;
for
(
int
i1
=
threadIdx
.
x
;
i1
<
d1
;
i1
+=
blockDim
.
x
){
for
(
int
i2
=
threadIdx
.
y
;
i2
<
d2
;
i2
+=
blockDim
.
y
){
int
a_idx
=
i0
*
sA0
+
i1
*
sA1
+
i2
*
sA2
;
int
b_idx
=
idx
*
sB0
+
i1
*
sB1
+
i2
*
sB2
;
a
[
a_idx
]
=
b
[
b_idx
];
}
}
}
}
// Pointor to 1 int on the device
// Used in CudaNdarray_TakeFrom to tell that there is an out of bound error
// When it exist, it should always be 0
// So if there is an error, we must reset it to 0 BEFORE we raise the error
// This prevent us from setting it to 0 before each use
static
int
*
err_var
=
NULL
;
//PyObject* PyArray_TakeFrom(PyArrayObject* self, PyObject* indices, int axis, PyArrayObject* ret, NPY_CLIPMODE clipmode)
//TODO: support other clip mode then raise(clip, wrap)
//TODO: what if the indices take more then 32 bits?
//self is the input that we copy data from.
PyObject
*
CudaNdarray_TakeFrom
(
CudaNdarray
*
self
,
PyObject
*
args
){
int
verbose
=
0
;
PyObject
*
indices_obj
=
NULL
;
//int axis; Default None, that mean the flattened array.
PyObject
*
axis_obj
=
Py_None
;
PyObject
*
out_obj
=
Py_None
;
PyObject
*
clipmode_obj
=
NULL
;
if
(
!
PyArg_ParseTuple
(
args
,
"O|OOO"
,
&
indices_obj
,
&
axis_obj
,
&
out_obj
,
&
clipmode_obj
))
return
NULL
;
//Check argument indices
//TODO: if not a numpy.ndarray, convert to numpy.ndarray
//TODO: If a CudaNdarray, accept it and suppose the data is int32? is float32 number of int?
//TODO: Support ndarray of other dtype then int32
//TODO: support list of indices that are not c_contiguous
CudaNdarray
*
indices
=
NULL
;
if
(
CudaNdarray_Check
(
indices_obj
))
{
if
(
verbose
)
printf
(
"cudandarray indices
\n
"
);
indices
=
(
CudaNdarray
*
)
indices_obj
;
Py_INCREF
(
indices
);
}
else
if
(
PyArray_Check
(
indices_obj
))
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"CudaNdarray_TakeFrom: The indices must cudandarray with float32 value."
);
return
NULL
;
if
(
verbose
)
printf
(
"ndarray indices
\n
"
);
if
(
PyArray_TYPE
(
indices_obj
)
!=
NPY_INT32
)
{
PyErr_SetString
(
PyExc_TypeError
,
"CudaNdarray_TakeFrom: need a ndarray for indices with dtype int32"
);
return
NULL
;
}
if
(((
PyArrayObject
*
)
indices_obj
)
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_TypeError
,
"CudaNdarray_TakeFrom: need a CudaNdarray of indices with only 1 dimensions"
);
return
NULL
;
}
PyArray_Descr
*
float32_descr
=
PyArray_DescrFromType
(
NPY_FLOAT32
);
PyObject
*
indices_float32
=
NULL
;
indices_float32
=
PyArray_View
((
PyArrayObject
*
)
indices_obj
,
float32_descr
,
NULL
);
Py_DECREF
(
float32_descr
);
if
(
verbose
)
printf
(
"ndarray indices
\n
"
);
//indices_float32 = PyArray_Cast((PyArrayObject*)indices_obj,
// NPY_FLOAT32);
//Py_INCREF(indices_float32);
if
(
verbose
)
printf
(
"ndarray indices
\n
"
);
if
(
!
indices_float32
)
return
NULL
;
indices
=
(
CudaNdarray
*
)
CudaNdarray_New
();
if
(
verbose
)
printf
(
"ndarray after new
\n
"
);
if
(
!
indices
){
Py_DECREF
(
indices_float32
);
return
NULL
;
}
if
(
CudaNdarray_CopyFromArray
(
indices
,
(
PyArrayObject
*
)
indices_float32
)){
Py_DECREF
(
indices_float32
);
return
NULL
;
}
Py_DECREF
(
indices_float32
);
}
else
{
PyErr_SetString
(
PyExc_TypeError
,
"CudaNdarray_TakeFrom: need a CudaNdarray for indices"
);
return
NULL
;
}
if
(
verbose
)
{
printf
(
"indices used on the gpu
\n
"
);
fprint_CudaNdarray
(
stdout
,
indices
);
PyObject
*
used_indices
=
CudaNdarray_CreateArrayObj
(
indices
);
PyObject_Print
(
used_indices
,
stdout
,
0
);
Py_DECREF
(
used_indices
);
}
if
(
verbose
)
printf
(
"after print of object
\n
"
);
if
(
!
CudaNdarray_is_c_contiguous
(
indices
)
!=
0
)
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"CudaNdarray_TakeFrom: The indices must be contiguous in memory."
);
Py_DECREF
(
indices_obj
);
return
NULL
;
}
int
nb_indices
=
CudaNdarray_SIZE
((
CudaNdarray
*
)
indices
);
//Check argument axis
//TODO: implement the default and other axis
PyObject
*
axis_iobj
=
PyNumber_Long
(
axis_obj
);
if
(
!
axis_iobj
)
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"CudaNdarray_TakeFrom: axis must be convertisable to a long"
);
Py_DECREF
(
indices_obj
);
return
NULL
;
}
long
axis
=
PyInt_AsLong
(
axis_iobj
);
Py_DECREF
(
axis_iobj
);
axis_iobj
=
NULL
;
if
(
axis
!=
0
)
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"CudaNdarray_TakeFrom: only axis=0 is currently supported"
);
Py_DECREF
(
indices_obj
);
return
NULL
;
}
//Check argument out_obj
CudaNdarray
*
out
=
NULL
;
if
(
out_obj
&&
CudaNdarray_Check
(
out_obj
))
out
=
(
CudaNdarray
*
)
out_obj
;
if
(
out
&&
(
out
->
nd
!=
self
->
nd
||
CudaNdarray_HOST_DIMS
(
out
)[
0
]
!=
nb_indices
))
out
=
NULL
;
int
dims
[
self
->
nd
];
dims
[
0
]
=
nb_indices
;
for
(
int
i
=
1
;
i
<
self
->
nd
;
i
++
)
{
dims
[
i
]
=
CudaNdarray_HOST_DIMS
(
self
)[
i
];
if
(
out
&&
CudaNdarray_HOST_DIMS
(
out
)[
i
]
!=
dims
[
i
])
{
out
=
NULL
;
}
}
if
(
!
out
)
{
int
total_elements
=
nb_indices
;
for
(
int
i
=
1
;
i
<
self
->
nd
;
i
++
)
total_elements
*=
CudaNdarray_HOST_DIMS
(
self
)[
i
];
// total_elements now contains the size of the array, in reals
int
total_size
=
total_elements
*
sizeof
(
real
);
out
=
(
CudaNdarray
*
)
CudaNdarray_New
();
if
(
!
out
){
Py_DECREF
(
indices_obj
);
return
NULL
;
}
if
(
CudaNdarray_alloc_contiguous
(
out
,
self
->
nd
,
dims
))
{
Py_DECREF
(
out
);
Py_DECREF
(
indices_obj
);
return
NULL
;
}
}
else
{
Py_INCREF
(
out
);
}
//Check argument clipmode
if
(
clipmode_obj
)
{
char
*
clipmode
=
PyString_AsString
(
clipmode_obj
);
if
(
!
clipmode
){
Py_DECREF
(
indices_obj
);
Py_DECREF
(
out
);
return
NULL
;
}
if
(
strcmp
(
clipmode
,
"raise"
)
!=
0
)
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"CudaNdarray_TakeFrom: only the raise mode is currently supported"
);
Py_DECREF
(
indices_obj
);
Py_DECREF
(
out
);
return
NULL
;
}
Py_DECREF
(
clipmode_obj
);
}
void
(
*
k3
)(
const
int
,
const
int
,
const
int
,
const
float
*
,
float
*
,
const
int
,
const
int
,
const
int
,
const
float
*
,
const
int
,
const
int
,
const
int
,
const
int
,
int
*
);
k3
=
k_take_3
<
CPY
>
;
// Create the memory place that will store the error information.
if
(
err_var
==
NULL
)
{
err_var
=
(
int
*
)
device_malloc
(
sizeof
(
int
));
if
(
!
err_var
)
{
// PyErr set by device_malloc
Py_DECREF
(
indices_obj
);
Py_DECREF
(
out
);
return
NULL
;
}
cudaError_t
err
=
cudaMemset
((
void
*
)
err_var
,
0
,
sizeof
(
int
));
if
(
cudaSuccess
!=
err
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error setting device error code to 0. %s"
,
cudaGetErrorString
(
err
));
Py_DECREF
(
indices_obj
);
Py_DECREF
(
out
);
return
NULL
;
}
}
dim3
n_blocks
(
std
::
min
(
CudaNdarray_HOST_DIMS
(
out
)[
0
],
65535
),
1
,
1
);
switch
(
self
->
nd
)
{
case
1
:
{
dim3
n_threads
(
1
,
1
,
1
);
if
(
verbose
)
printf
(
"kernel config: (n_blocks.x=%d, n_blocks.y=%d,"
" n_threads.x=%i, n_threads.y=%i)
\n
"
,
n_blocks
.
x
,
n_blocks
.
y
,
n_threads
.
x
,
n_threads
.
y
);
k3
<<<
n_blocks
,
n_threads
>>>
(
dims
[
0
],
1
,
1
,
CudaNdarray_DEV_DATA
(
indices
),
CudaNdarray_DEV_DATA
(
out
),
CudaNdarray_HOST_STRIDES
(
out
)[
0
],
//strides
1
,
1
,
CudaNdarray_DEV_DATA
(
self
),
CudaNdarray_HOST_DIMS
(
self
)[
0
],
//For indices check
CudaNdarray_HOST_STRIDES
(
self
)[
0
],
//strides
1
,
1
,
err_var
);
}
break
;
case
2
:
{
dim3
n_threads
(
CudaNdarray_HOST_DIMS
(
out
)[
1
],
1
,
1
);
if
(
verbose
)
printf
(
"kernel config: (n_blocks.x=%d, n_blocks.y=%d,"
" n_threads.x=%i, n_threads.y=%i)
\n
"
,
n_blocks
.
x
,
n_blocks
.
y
,
n_threads
.
x
,
n_threads
.
y
);
k3
<<<
n_blocks
,
n_threads
>>>
(
dims
[
0
],
//dimensions
dims
[
1
],
1
,
CudaNdarray_DEV_DATA
(
indices
),
CudaNdarray_DEV_DATA
(
out
),
CudaNdarray_HOST_STRIDES
(
out
)[
0
],
//strides
CudaNdarray_HOST_STRIDES
(
out
)[
1
],
1
,
CudaNdarray_DEV_DATA
(
self
),
CudaNdarray_HOST_DIMS
(
self
)[
0
],
//For indices check
CudaNdarray_HOST_STRIDES
(
self
)[
0
],
//strides
CudaNdarray_HOST_STRIDES
(
self
)[
1
],
1
,
err_var
);
}
break
;
case
3
:
{
dim3
n_threads
(
CudaNdarray_HOST_DIMS
(
out
)[
1
],
CudaNdarray_HOST_DIMS
(
out
)[
2
],
1
);
if
(
verbose
)
printf
(
"kernel config: (n_blocks.x=%d, n_blocks.y=%d,"
" n_threads.x=%i, n_threads.y=%i)
\n
"
,
n_blocks
.
x
,
n_blocks
.
y
,
n_threads
.
x
,
n_threads
.
y
);
k3
<<<
n_blocks
,
n_threads
>>>
(
dims
[
0
],
//dimensions
dims
[
1
],
dims
[
2
],
CudaNdarray_DEV_DATA
(
indices
),
CudaNdarray_DEV_DATA
(
out
),
CudaNdarray_HOST_STRIDES
(
out
)[
0
],
//strides
CudaNdarray_HOST_STRIDES
(
out
)[
1
],
CudaNdarray_HOST_STRIDES
(
out
)[
2
],
CudaNdarray_DEV_DATA
(
self
),
CudaNdarray_HOST_DIMS
(
self
)[
0
],
//For indices check
CudaNdarray_HOST_STRIDES
(
self
)[
0
],
//strides
CudaNdarray_HOST_STRIDES
(
self
)[
1
],
CudaNdarray_HOST_STRIDES
(
self
)[
2
],
err_var
);
}
break
;
default
:
PyErr_SetString
(
PyExc_NotImplementedError
,
"CudaNdarray_TakeFrom: only input with 1, 2 or 3"
" dimensions are currently supported"
);
}
CNDA_THREAD_SYNC
;
cudaError_t
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Cuda error: %s: %s.
\n
"
,
"CudaNdarray_TakeFrom"
,
cudaGetErrorString
(
err
));
Py_DECREF
(
indices_obj
);
Py_DECREF
(
out
);
return
NULL
;
}
int
cpu_err_var
=-
10
;
err
=
cudaMemcpy
(
&
cpu_err_var
,
err_var
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
);
if
(
cudaSuccess
!=
err
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Cuda error: %s: %s when trying to get the error value.
\n
"
,
"CudaNdarray_TakeFrom"
,
cudaGetErrorString
(
err
));
Py_DECREF
(
indices_obj
);
Py_DECREF
(
out
);
return
NULL
;
}
if
(
cpu_err_var
!=
0
)
{
PyErr_Format
(
PyExc_IndexError
,
"Cuda error: %s: The error code on the gpu is %i.
\n
"
,
"CudaNdarray_TakeFrom"
,
cpu_err_var
);
// Must reset it to 0 to don't reset it before each use.
err
=
cudaMemset
((
void
*
)
err_var
,
0
,
sizeof
(
int
));
if
(
cudaSuccess
!=
err
)
{
PyErr_Format
(
PyExc_MemoryError
,
"Error setting device error code to 0 after having an index error. %s"
,
cudaGetErrorString
(
err
));
Py_DECREF
(
indices_obj
);
Py_DECREF
(
out
);
return
NULL
;
}
Py_DECREF
(
indices_obj
);
Py_DECREF
(
out
);
return
NULL
;
}
Py_DECREF
(
indices_obj
);
if
(
verbose
)
printf
(
"TAKE SUCCEDED
\n
"
);
return
(
PyObject
*
)
out
;
}
PyObject
*
CudaNdarray_SetStride
(
CudaNdarray
*
self
,
PyObject
*
args
)
PyObject
*
CudaNdarray_SetStride
(
CudaNdarray
*
self
,
PyObject
*
args
)
{
{
int
pos
,
stride
;
int
pos
,
stride
;
...
@@ -787,6 +1150,9 @@ static PyMethodDef CudaNdarray_methods[] =
...
@@ -787,6 +1150,9 @@ static PyMethodDef CudaNdarray_methods[] =
{
"_set_stride"
,
{
"_set_stride"
,
(
PyCFunction
)
CudaNdarray_SetStride
,
METH_VARARGS
,
(
PyCFunction
)
CudaNdarray_SetStride
,
METH_VARARGS
,
"For integer arguments (i, s), set the 'i'th stride to 's'"
},
"For integer arguments (i, s), set the 'i'th stride to 's'"
},
{
"take"
,
(
PyCFunction
)
CudaNdarray_TakeFrom
,
METH_VARARGS
,
"Equivalent of numpy.take"
},
{
"_set_shape_i"
,
{
"_set_shape_i"
,
(
PyCFunction
)
CudaNdarray_SetShapeI
,
METH_VARARGS
,
(
PyCFunction
)
CudaNdarray_SetShapeI
,
METH_VARARGS
,
"For integer arguments (i, s), set the 'i'th shape to 's'"
},
"For integer arguments (i, s), set the 'i'th shape to 's'"
},
...
@@ -869,14 +1235,6 @@ CudaNdarray_add(PyObject* py_self, PyObject * py_other)
...
@@ -869,14 +1235,6 @@ CudaNdarray_add(PyObject* py_self, PyObject * py_other)
return
(
PyObject
*
)
rval
;
return
(
PyObject
*
)
rval
;
}
}
enum
operator_t
{
IADD
=
0
,
IDIV
,
CPY
,
N_ELEMWISE_OPS
// What this mean? It is not used
};
template
<
int
operator_num
>
template
<
int
operator_num
>
__global__
void
k_ielem_3
(
const
int
d0
,
const
int
d1
,
const
int
d2
,
__global__
void
k_ielem_3
(
const
int
d0
,
const
int
d1
,
const
int
d2
,
float
*
a
,
const
int
sA0
,
const
int
sA1
,
const
int
sA2
,
float
*
a
,
const
int
sA0
,
const
int
sA1
,
const
int
sA2
,
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
6c955ecf
...
@@ -338,6 +338,12 @@ DllExport int CudaNdarray_reduce_min(CudaNdarray * self, CudaNdarray * A);
...
@@ -338,6 +338,12 @@ DllExport int CudaNdarray_reduce_min(CudaNdarray * self, CudaNdarray * A);
DllExport
int
CudaNdarray_reduce_max
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_reduce_max
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_dimshuffle
(
CudaNdarray
*
self
,
unsigned
int
len
,
const
int
*
pattern
);
DllExport
int
CudaNdarray_dimshuffle
(
CudaNdarray
*
self
,
unsigned
int
len
,
const
int
*
pattern
);
//PyObject* PyArray_TakeFrom(PyArrayObject* self, PyObject* indices, int axis, PyArrayObject* ret, NPY_CLIPMODE clipmode)
//PyObject*
//CudaNdarray_TakeFrom(CudaNdarray* self, PyObject* indices, int axis,
// PyArrayObject* ret, NPY_CLIPMODE clipmode);
PyObject
*
CudaNdarray_TakeFrom
(
CudaNdarray
*
self
,
PyObject
*
args
);
static
void
fprint_CudaNdarray
(
FILE
*
fd
,
const
CudaNdarray
*
self
);
static
void
fprint_CudaNdarray
(
FILE
*
fd
,
const
CudaNdarray
*
self
);
...
...
theano/sandbox/cuda/tests/test_basic_ops.py
浏览文件 @
6c955ecf
import
sys
,
time
,
unittest
import
copy
import
sys
import
time
import
unittest
from
theano.compile.pfunc
import
pfunc
from
theano.compile.pfunc
import
pfunc
from
theano
import
tensor
from
theano
import
tensor
...
@@ -846,6 +849,47 @@ class T_subtensor(theano.tensor.tests.test_basic.T_subtensor):
...
@@ -846,6 +849,47 @@ class T_subtensor(theano.tensor.tests.test_basic.T_subtensor):
return
super
(
theano
.
tensor
.
tests
.
test_basic
.
T_subtensor
,
return
super
(
theano
.
tensor
.
tests
.
test_basic
.
T_subtensor
,
self
)
.
__init__
(
name
)
self
)
.
__init__
(
name
)
def
test_adv_sub1_fast
(
self
):
""" We check that we correctly used the fast version"""
rand
=
numpy
.
random
.
rand
for
data
,
idx
,
fast
in
[(
rand
(
70000
),
range
(
70000
),
True
),
(
rand
(
70000
,
5
),
range
(
70000
),
True
),
(
rand
(
70000
,
2
,
3
),
range
(
70000
),
True
),
(
rand
(
4
,
5
),
[
2
,
3
],
True
),
(
rand
(
4
,
2
,
3
),
[
0
,
3
],
True
),
(
rand
(
4
,
2
,
3
),
[
3
,
3
,
1
,
1
,
2
,
2
,
0
,
0
],
True
),
(
rand
(
4
,
2
,
3
),
[
3
,
3
,
1
,
1
,
2
,
2
,
0
,
0
,
-
1
,
-
2
,
-
3
,
-
4
],
True
),
# Test 4 dims as gpu. code use another algo
# in that case. This new algo is not as much
# optimized for that case.
(
rand
(
4
,
4
,
2
,
3
),
[
3
,
3
,
1
,
1
,
2
,
2
,
0
,
0
,
-
1
,
-
2
,
-
3
,
-
4
],
False
),
]:
data
=
numpy
.
asarray
(
data
,
dtype
=
self
.
dtype
)
n
=
self
.
shared
(
data
)
# Test with c_contiguous input
t
=
self
.
adv_sub1
()(
n
,
idx
)
t
.
owner
.
op
.
assert_fast
=
True
# input c_contiguous, so we reshape
val
=
self
.
eval_output_and_check
(
t
,
list
=
True
)
val
=
numpy
.
asarray
(
val
)
good
=
data
[
idx
]
self
.
assertTrue
(
val
.
ndim
==
data
.
ndim
)
self
.
assertTrue
(
numpy
.
allclose
(
val
,
good
),
(
val
,
good
))
# Test with input strided
t
=
self
.
adv_sub1
()(
n
[::
-
1
],
idx
)
t
.
owner
.
op
.
assert_fast
=
fast
val
=
theano
.
function
([],
t
,
mode
=
self
.
mode
)()
val
=
numpy
.
asarray
(
val
)
good
=
data
[::
-
1
][
idx
]
self
.
assertTrue
(
val
.
ndim
==
data
.
ndim
)
self
.
assertTrue
(
numpy
.
allclose
(
val
,
good
),
(
val
,
good
))
def
test_advinc_subtensor1
():
def
test_advinc_subtensor1
():
""" Test the second case in the opt local_gpu_advanced_incsubtensor1 """
""" Test the second case in the opt local_gpu_advanced_incsubtensor1 """
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论