Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
96477de4
提交
96477de4
authored
12月 09, 2013
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Copy GpuSubtensor to the new back-end, enable its test.
上级
61ad2da0
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
131 行增加
和
3 行删除
+131
-3
subtensor.py
theano/sandbox/gpuarray/subtensor.py
+126
-1
test_subtensor.py
theano/sandbox/gpuarray/tests/test_subtensor.py
+5
-2
没有找到文件。
theano/sandbox/gpuarray/subtensor.py
浏览文件 @
96477de4
...
@@ -4,7 +4,7 @@ import numpy
...
@@ -4,7 +4,7 @@ import numpy
import
theano
import
theano
from
theano
import
tensor
,
gof
from
theano
import
tensor
,
gof
from
theano.tensor.subtensor
import
Subtensor
,
get_idx_list
from
theano.tensor.subtensor
import
IncSubtensor
,
Subtensor
,
get_idx_list
from
theano.gof.python25
import
all
,
any
from
theano.gof.python25
import
all
,
any
...
@@ -154,3 +154,128 @@ class GpuSubtensor(HideC, Subtensor):
...
@@ -154,3 +154,128 @@ class GpuSubtensor(HideC, Subtensor):
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
5
,)
return
(
5
,)
class
GpuIncSubtensor
(
HideC
,
IncSubtensor
):
"""
Implement IncSubtensor on the gpu.
Note: The optimization to make this inplace is in tensor/opt.
The same optimization handles IncSubtensor and GpuIncSubtensor.
This Op has c_code too; it inherits tensor.IncSubtensor's c_code.
The helper methods like do_type_checking, copy_of_x, etc. specialize
the c_code for this Op.
"""
def
make_node
(
self
,
x
,
y
,
*
inputs
):
x
=
as_cuda_ndarray_variable
(
x
)
y
=
as_cuda_ndarray_variable
(
y
)
rval
=
tensor
.
IncSubtensor
.
make_node
(
self
,
x
,
y
,
*
inputs
)
return
Apply
(
self
,
[
x
,
y
]
+
rval
.
inputs
[
2
:],
[
x
.
type
()])
def
do_type_checking
(
self
,
node
):
""" Should raise NotImplementedError if c_code does not support
the types involved in this node.
"""
if
not
isinstance
(
node
.
inputs
[
0
]
.
type
,
CudaNdarrayType
):
raise
NotImplementedError
()
def
copy_of_x
(
self
,
x
):
"""
:param x: a string giving the name of a C variable
pointing to an array
:return: C code expression to make a copy of x
Base class uses `PyArrayObject *`, subclasses may override for
different types of arrays.
"""
return
"""(CudaNdarray*) CudaNdarray_Copy(
%(x)
s)"""
%
locals
()
def
decl_view
(
self
):
return
"CudaNdarray* zview = NULL;"
def
make_view_array
(
self
,
x
,
view_ndim
):
"""
:param x: a string identifying an array to be viewed
:param view_ndim: a string specifying the number of dimensions
to have in the view
This doesn't need to actually set up the view with the
right indexing; we'll do that manually later.
"""
ret
=
"""zview = (CudaNdarray*) CudaNdarray_New(
%(view_ndim)
s);
if (CudaNdarray_set_device_data(
zview,
CudaNdarray_DEV_DATA(
%(x)
s) + xview_offset/4,
(PyObject*)
%(x)
s))
{
zview = NULL;
PyErr_Format(PyExc_RuntimeError,
"GpuSubtensor is not able to set the"
" devdata field of the view");
}else{
cnda_mark_dev_structure_dirty(zview);
for(int idx=0;idx <
%(view_ndim)
s; idx++){
if(xview_dims[idx]==1)
CudaNdarray_set_stride(zview, idx, 0);
else
CudaNdarray_set_stride(zview, idx, xview_strides[idx]);
CudaNdarray_set_dim(zview, idx, xview_dims[idx]);
}
}
"""
%
locals
()
return
ret
def
get_helper_c_code_args
(
self
):
""" Return a dictionary of arguments to use with helper_c_code"""
return
{
'c_prefix'
:
'CudaNdarray'
,
'strides_mul'
:
4
}
def
copy_into
(
self
,
view
,
source
):
"""
view: string, C code expression for an array
source: string, C code expression for an array
returns a C code expression to copy source into view, and
return 0 on success
"""
return
"""CudaNdarray_CopyFromCudaNdarray(
%(view)
s,
%(source)
s)"""
%
locals
()
def
set_view_base
(
self
,
x
,
fail
):
return
"""
//Set the base only now
if(CudaNdarray_set_device_data(zview, CudaNdarray_DEV_DATA(zview),
%(x)
s)){
PyErr_Format(PyExc_RuntimeError,
"GpuSubtensor is not able to set"
" the base of the view array");
Py_XDECREF(zview);
%(fail)
s;
}"""
%
locals
()
def
add_to_zview
(
self
,
x
,
fail
):
return
"""
PyObject * add_result = CudaNdarray_inplace_add((PyObject *) zview,
(PyObject *) py_
%(x)
s);
if (! add_result )
{
Py_DECREF(zview);
%(fail)
s;
}
else
{
Py_DECREF(add_result);
}
"""
%
locals
()
def
c_code_cache_version
(
self
):
parent_version
=
super
(
GpuIncSubtensor
,
self
)
.
c_code_cache_version
()
if
parent_version
:
return
parent_version
+
(
0
,)
return
()
theano/sandbox/gpuarray/tests/test_subtensor.py
浏览文件 @
96477de4
from
theano.tensor.tests.test_subtensor
import
T_subtensor
from
theano.tensor.tests.test_subtensor
import
T_subtensor
from
theano.sandbox.gpuarray.basic_ops
import
(
HostFromGpu
,
GpuFromHost
)
from
theano.sandbox.gpuarray.basic_ops
import
(
HostFromGpu
,
GpuFromHost
)
from
theano.sandbox.gpuarray.subtensor
import
GpuSubtensor
from
theano.sandbox.gpuarray.subtensor
import
Gpu
IncSubtensor
,
Gpu
Subtensor
from
theano.sandbox.gpuarray.type
import
gpuarray_shared_constructor
from
theano.sandbox.gpuarray.type
import
gpuarray_shared_constructor
...
@@ -11,6 +11,7 @@ from theano.compile import DeepCopyOp
...
@@ -11,6 +11,7 @@ from theano.compile import DeepCopyOp
from
theano
import
tensor
from
theano
import
tensor
class
G_subtensor
(
T_subtensor
):
class
G_subtensor
(
T_subtensor
):
def
shortDescription
(
self
):
def
shortDescription
(
self
):
return
None
return
None
...
@@ -19,8 +20,10 @@ class G_subtensor(T_subtensor):
...
@@ -19,8 +20,10 @@ class G_subtensor(T_subtensor):
T_subtensor
.
__init__
(
self
,
name
,
T_subtensor
.
__init__
(
self
,
name
,
shared
=
gpuarray_shared_constructor
,
shared
=
gpuarray_shared_constructor
,
sub
=
GpuSubtensor
,
sub
=
GpuSubtensor
,
inc_sub
=
GpuIncSubtensor
,
mode
=
mode_with_gpu
,
mode
=
mode_with_gpu
,
# avoid errors with limited devices
# avoid errors with limited devices
dtype
=
'float32'
,
dtype
=
'float32'
,
ignore_topo
=
(
HostFromGpu
,
GpuFromHost
,
DeepCopyOp
))
ignore_topo
=
(
HostFromGpu
,
GpuFromHost
,
DeepCopyOp
))
assert
self
.
sub
==
GpuSubtensor
assert
self
.
sub
==
GpuSubtensor
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论