Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1d46d73d
提交
1d46d73d
authored
4月 07, 2014
作者:
Pierre Luc Carrier
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Copied GpuAdvancedIncSubtensor1() and GpuAdvancedIncSubtensor1_dev() ops and…
Copied GpuAdvancedIncSubtensor1() and GpuAdvancedIncSubtensor1_dev() ops and their tests to sandbox/gpuarray
上级
90b5a114
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
326 行增加
和
0 行删除
+326
-0
subtensor.py
theano/sandbox/gpuarray/subtensor.py
+305
-0
test_subtensor.py
theano/sandbox/gpuarray/tests/test_subtensor.py
+21
-0
没有找到文件。
theano/sandbox/gpuarray/subtensor.py
浏览文件 @
1d46d73d
...
@@ -357,3 +357,307 @@ class GpuIncSubtensor(IncSubtensor):
...
@@ -357,3 +357,307 @@ class GpuIncSubtensor(IncSubtensor):
if
not
parent_version
or
not
elemwise_version
:
if
not
parent_version
or
not
elemwise_version
:
return
return
return
parent_version
+
elemwise_version
+
(
0
,)
return
parent_version
+
elemwise_version
+
(
0
,)
class
GpuAdvancedIncSubtensor1
(
tensor
.
AdvancedIncSubtensor1
,
GpuOp
):
"""
Implement AdvancedIncSubtensor1 on the gpu.
"""
def
make_node
(
self
,
x
,
y
,
ilist
):
x_
=
as_cuda_ndarray_variable
(
x
)
y_
=
as_cuda_ndarray_variable
(
y
)
ilist_
=
tensor
.
as_tensor_variable
(
ilist
)
assert
x_
.
type
.
dtype
==
y_
.
type
.
dtype
assert
x_
.
type
.
ndim
>=
y_
.
type
.
ndim
if
ilist_
.
type
.
dtype
[:
3
]
not
in
(
'int'
,
'uin'
):
raise
TypeError
(
'index must be integers'
)
if
ilist_
.
type
.
broadcastable
!=
(
False
,):
raise
TypeError
(
'index must be vector'
)
if
x_
.
type
.
ndim
==
0
:
raise
TypeError
(
'cannot index into a scalar'
)
if
x_
.
type
.
broadcastable
[
0
]:
# the caller should have made a copy of x len(ilist) times
raise
TypeError
(
'cannot index into a broadcastable dimension'
)
return
Apply
(
self
,
[
x_
,
y_
,
ilist_
],
[
x_
.
type
()])
# CudaNdarray_Subscript() doesn't support Advanced slicing.
# But we can't use the parent version that loops on each index
# as we also need to loop when set_instead_of_inc is True and the
# parent doesn't loop in that case.
def
perform
(
self
,
node
,
inp
,
out_
):
# TODO opt to make this inplace
x
,
y
,
idx
=
inp
out
,
=
out_
if
not
self
.
inplace
:
x
=
x
.
copy
()
if
self
.
set_instead_of_inc
:
# CudaNdarray __setitem__ doesn't do broadcast nor support
# list of index.
assert
y
.
ndim
<=
x
.
ndim
# Should be guaranteed by `make_node`
if
y
.
ndim
==
x
.
ndim
:
assert
len
(
y
)
==
len
(
idx
)
for
(
j
,
i
)
in
enumerate
(
idx
):
x
[
i
]
=
y
[
j
]
else
:
for
i
in
idx
:
x
[
i
]
=
y
else
:
# If `y` has as many dimensions as `x`, then we want to iterate
# jointly on `x` and `y`. Otherwise, it means `y` should be
# broadcasted to fill all relevant rows of `x`.
assert
y
.
ndim
<=
x
.
ndim
# Should be guaranteed by `make_node`
if
y
.
ndim
==
x
.
ndim
:
assert
len
(
y
)
==
len
(
idx
)
for
(
j
,
i
)
in
enumerate
(
idx
):
x
[
i
]
+=
y
[
j
]
else
:
for
i
in
idx
:
x
[
i
]
+=
y
out
[
0
]
=
x
def
c_code_cache_version
(
self
):
return
(
3
,)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
if
(
self
.
set_instead_of_inc
)
or
\
(
node
.
inputs
[
0
]
.
ndim
!=
node
.
inputs
[
1
]
.
ndim
):
raise
NotImplementedError
(
"This case does not have C code yet."
)
x
=
inputs
[
0
]
y
=
inputs
[
1
]
ind
=
inputs
[
2
]
out
=
outputs
[
0
]
fail
=
sub
[
'fail'
]
inplace
=
int
(
self
.
inplace
)
return
"""
PyObject *x_obj, *y_obj, *row_x, *row_y;
PyObject *x_rowind_obj, *y_rowind_obj;
dtype_
%(ind)
s *p_index;
int num_indices, j;
int ret;
num_indices = PyArray_SIZE(
%(ind)
s);
if ((num_indices - 1) > LONG_MAX) {
PyErr_Format(PyExc_AssertionError,
"num_indices
%%
d exceeds LONG_MAX + 1", num_indices);
%(fail)
s;
}
Py_XDECREF(
%(out)
s);
if (!
%(inplace)
s) {
%(out)
s = (CudaNdarray*)CudaNdarray_Copy(
%(x)
s);
} else {
%(out)
s =
%(x)
s;
Py_XINCREF(
%(out)
s);
}
x_obj = (PyObject*)CudaNdarray_View(
%(out)
s);
y_obj = (PyObject*)CudaNdarray_View(
%(y)
s);
for (j = 0;j < num_indices; j++) {
p_index = (dtype_
%(ind)
s *)PyArray_GETPTR1(
%(ind)
s, j);
x_rowind_obj = PyInt_FromLong(*p_index);
if (PyInt_AsLong(x_rowind_obj) != (*p_index)) {
PyErr_Format(PyExc_AssertionError,
"Error in converting row index to integer from long");
// Dec Ref what ever we have increfed or allocated so far
// We deallocate objects exactly in the reverse order they were allocated.
Py_XDECREF(x_rowind_obj);
Py_XDECREF(y_obj);
Py_XDECREF(x_obj);
%(fail)
s;
}
y_rowind_obj = PyInt_FromLong(j);
row_x = CudaNdarray_Subscript(x_obj, x_rowind_obj);
row_y = CudaNdarray_Subscript(y_obj, y_rowind_obj);
if ((row_x == NULL) || (row_y == NULL)) {
Py_XDECREF(row_y);
Py_XDECREF(row_x);
Py_XDECREF(y_rowind_obj);
Py_XDECREF(x_rowind_obj);
Py_XDECREF(y_obj);
Py_XDECREF(x_obj);
%(fail)
s;
}
ret = CudaNdarray_inplace_elemwise(row_x, row_y, IADD);
if (ret != 0) {
Py_XDECREF(row_y);
Py_XDECREF(row_x);
Py_XDECREF(y_rowind_obj);
Py_XDECREF(x_rowind_obj);
Py_XDECREF(y_obj);
Py_XDECREF(x_obj);
%(fail)
s;
}
Py_XDECREF(row_y);
Py_XDECREF(row_x);
Py_XDECREF(y_rowind_obj);
Py_XDECREF(x_rowind_obj);
}
Py_XDECREF(y_obj);
Py_XDECREF(x_obj);
if (!
%(out)
s) {
%(fail)
s
}
"""
%
locals
()
class
GpuAdvancedIncSubtensor1_dev20
(
GpuAdvancedIncSubtensor1
):
"""Implement AdvancedIncSubtensor1 on the gpu, but use function
only avail on compute capability 2.0 and more recent.
"""
def
make_node
(
self
,
x
,
y
,
ilist
):
"""It defer from GpuAdvancedIncSubtensor1 in that it make sure
the index are of type long.
"""
x_
=
as_cuda_ndarray_variable
(
x
)
y_
=
as_cuda_ndarray_variable
(
y
)
ilist_
=
tensor
.
as_tensor_variable
(
ilist
)
convert_map
=
{
8
:
tensor
.
basic
.
_convert_to_int8
,
16
:
tensor
.
basic
.
_convert_to_int16
,
32
:
tensor
.
basic
.
_convert_to_int32
,
64
:
tensor
.
basic
.
_convert_to_int64
}
intwidth
=
theano
.
gof
.
compiledir
.
python_int_bitwidth
()
ilist_
=
convert_map
[
intwidth
](
ilist_
)
assert
x_
.
type
.
dtype
==
y_
.
type
.
dtype
assert
x_
.
type
.
ndim
>=
y_
.
type
.
ndim
if
ilist_
.
type
.
dtype
[:
3
]
not
in
(
'int'
,
'uin'
):
raise
TypeError
(
'index must be integers'
)
if
ilist_
.
type
.
broadcastable
!=
(
False
,):
raise
TypeError
(
'index must be vector'
)
if
x_
.
type
.
ndim
==
0
:
raise
TypeError
(
'cannot index into a scalar'
)
if
x_
.
type
.
broadcastable
[
0
]:
# the caller should have made a copy of x len(ilist) times
raise
TypeError
(
'cannot index into a broadcastable dimension'
)
return
Apply
(
self
,
[
x_
,
y_
,
ilist_
],
[
x_
.
type
()])
def
c_code_cache_version
(
self
):
return
(
2
,)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
active_device_no
=
theano
.
sandbox
.
cuda
.
active_device_number
()
compute_capability
=
device_properties
(
active_device_no
)[
'major'
]
if
((
self
.
set_instead_of_inc
)
or
(
node
.
inputs
[
0
]
.
ndim
!=
node
.
inputs
[
1
]
.
ndim
)
or
(
node
.
inputs
[
0
]
.
ndim
!=
2
)
or
(
compute_capability
<
2
)):
raise
NotImplementedError
(
"This case does not have C code yet."
)
x
=
inputs
[
0
]
y
=
inputs
[
1
]
ind
=
inputs
[
2
]
out
=
outputs
[
0
]
fail
=
sub
[
'fail'
]
inplace
=
int
(
self
.
inplace
)
return
"""
Py_XDECREF(
%(out)
s);
if (!
%(inplace)
s) {
%(out)
s = (CudaNdarray*)CudaNdarray_Copy(
%(x)
s);
} else {
%(out)
s =
%(x)
s;
Py_XINCREF(
%(out)
s);
}
CudaNdarray_vector_add_fast(
%(out)
s,
%(y)
s,
%(ind)
s);
if (!
%(out)
s) {
%(fail)
s
}
"""
%
locals
()
def
c_support_code_apply
(
self
,
node
,
nodename
):
return
"""
__global__ void k_vector_add_fast(int numRowsX,
int numColsX,
int stridesX0,
int stridesX1,
float *X,
int numRowsY,
int numColsY,
int stridesY0,
int stridesY1,
float *Y ,
long *d_indices_arr,
int num)
{
for (int i = (blockIdx.x); i < num; i += gridDim.x)
{
for(int j = (threadIdx.x); j < numColsX;j += blockDim.x)
{
int x_row = d_indices_arr[i];
int y_row = i;
atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)], Y[(y_row * stridesY0) + (j * stridesY1)]);
}
}
return;
}
void CudaNdarray_vector_add_fast(CudaNdarray* py_self, CudaNdarray* py_other, PyArrayObject *indices_arr)
{
const int *shapeX = CudaNdarray_HOST_DIMS(py_self);
const int *shapeY = CudaNdarray_HOST_DIMS(py_other);
const int *strX = CudaNdarray_HOST_STRIDES(py_self);
const int *strY = CudaNdarray_HOST_STRIDES(py_other);
unsigned int size = (unsigned int)PyArray_SIZE(indices_arr);
unsigned int numcolsX = shapeX[1];
unsigned int num_threads_per_block = std::min(numcolsX, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
unsigned int num_blocks = std::min(size ,(unsigned int)NUM_VECTOR_OP_BLOCKS);
dim3 n_blocks(num_blocks);
dim3 n_threads(num_threads_per_block);
long *d_indices_arr = NULL;
PyArrayObject *cpu_indices_arr = PyArray_GETCONTIGUOUS(indices_arr);
d_indices_arr = (long*)device_malloc(PyArray_NBYTES(cpu_indices_arr));
assert(d_indices_arr);
cudaError_t err = cudaMemcpy(d_indices_arr,
PyArray_DATA(cpu_indices_arr),
PyArray_NBYTES(cpu_indices_arr),
cudaMemcpyHostToDevice);
assert(err == cudaSuccess);
k_vector_add_fast<<<n_blocks, n_threads>>>(shapeX[0],
shapeX[1],
strX[0],
strX[1],
CudaNdarray_DEV_DATA(py_self),
shapeY[0],
shapeY[1],
strY[0],
strY[1],
CudaNdarray_DEV_DATA(py_other),
d_indices_arr,
PyArray_SIZE(indices_arr)
);
device_free(d_indices_arr);
Py_XDECREF(cpu_indices_arr);
return;
}
"""
%
locals
()
\ No newline at end of file
theano/sandbox/gpuarray/tests/test_subtensor.py
浏览文件 @
1d46d73d
...
@@ -29,3 +29,24 @@ class G_subtensor(T_subtensor):
...
@@ -29,3 +29,24 @@ class G_subtensor(T_subtensor):
# GPU opt can't run in fast_compile only.
# GPU opt can't run in fast_compile only.
self
.
fast_compile
=
False
self
.
fast_compile
=
False
assert
self
.
sub
==
GpuSubtensor
assert
self
.
sub
==
GpuSubtensor
def
test_advinc_subtensor1
():
""" Test the second case in the opt local_gpu_advanced_incsubtensor1 """
for
shp
in
[(
3
,
3
),
(
3
,
3
,
3
)]:
shared
=
cuda
.
shared_constructor
xval
=
numpy
.
arange
(
numpy
.
prod
(
shp
),
dtype
=
'float32'
)
.
reshape
(
shp
)
+
1
yval
=
numpy
.
empty
((
2
,)
+
shp
[
1
:],
dtype
=
'float32'
)
yval
[:]
=
10
x
=
shared
(
xval
,
name
=
'x'
)
y
=
T
.
tensor
(
dtype
=
'float32'
,
broadcastable
=
(
False
,)
*
len
(
shp
),
name
=
'y'
)
expr
=
T
.
advanced_inc_subtensor1
(
x
,
y
,
[
0
,
2
])
f
=
theano
.
function
([
y
],
expr
,
mode
=
mode_with_gpu
)
assert
sum
([
isinstance
(
node
.
op
,
cuda
.
GpuAdvancedIncSubtensor1
)
for
node
in
f
.
maker
.
fgraph
.
toposort
()])
==
1
rval
=
f
(
yval
)
rep
=
xval
.
copy
()
rep
[[
0
,
2
]]
+=
yval
assert
numpy
.
allclose
(
rval
,
rep
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论