Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
975e0d2b
提交
975e0d2b
authored
9月 16, 2015
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3227 from abergeron/gpua_advsub1
Implement GpuAdvancedSubtensor1 for gpuarray
上级
c13853ad
e3474eda
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
38 行增加
和
14 行删除
+38
-14
gpuarray_helper.h
theano/sandbox/gpuarray/gpuarray_helper.h
+11
-4
opt.py
theano/sandbox/gpuarray/opt.py
+21
-5
subtensor.py
theano/sandbox/gpuarray/subtensor.py
+0
-0
test_subtensor.py
theano/sandbox/gpuarray/tests/test_subtensor.py
+2
-0
test_subtensor.py
theano/tensor/tests/test_subtensor.py
+4
-4
test_flake8.py
theano/tests/test_flake8.py
+0
-1
没有找到文件。
theano/sandbox/gpuarray/gpuarray_helper.h
浏览文件 @
975e0d2b
...
@@ -4,6 +4,8 @@
...
@@ -4,6 +4,8 @@
#include <string.h>
#include <string.h>
#include <gpuarray_api.h>
#include <gpuarray_api.h>
#include <numpy_compat.h>
#include <numpy_compat.h>
#include <gpuarray/util.h>
static
int
theano_size_check
(
PyGpuArrayObject
*
a
,
unsigned
int
nd
,
static
int
theano_size_check
(
PyGpuArrayObject
*
a
,
unsigned
int
nd
,
const
size_t
*
dims
,
int
typecode
)
{
const
size_t
*
dims
,
int
typecode
)
{
...
@@ -42,9 +44,14 @@ static PyGpuArrayObject *theano_try_copy(PyGpuArrayObject *out,
...
@@ -42,9 +44,14 @@ static PyGpuArrayObject *theano_try_copy(PyGpuArrayObject *out,
return
out
;
return
out
;
}
}
/* This is guaranteed to work and return the raw CUDA/OpenCL object on
static
inline
void
*
PyGpuArray_DEV_DATA
(
PyGpuArrayObject
*
a
)
{
* all recent (as of June 2015) version of libgpuarray. This is also
/* This is guaranteed to work and return the raw CUDA/OpenCL object on
* promised to keep working in future versions. */
* all recent (as of June 2015) version of libgpuarray. This is also
#define PyGpuArray_DEV_DATA(ary) (*(void **)((ary)->ga.data))
* promised to keep working in future versions. */
char
*
p
=
*
((
char
**
)
a
->
ga
.
data
);
/* This only works on cuda since we have a real pointer. */
return
(
void
*
)(
p
+
a
->
ga
.
offset
);
}
#endif
#endif
theano/sandbox/gpuarray/opt.py
浏览文件 @
975e0d2b
...
@@ -35,6 +35,7 @@ from .nnet import (GpuCrossentropySoftmaxArgmax1HotWithBias,
...
@@ -35,6 +35,7 @@ from .nnet import (GpuCrossentropySoftmaxArgmax1HotWithBias,
from
.elemwise
import
(
GpuElemwise
,
GpuDimShuffle
,
GpuCAReduceCuda
,
from
.elemwise
import
(
GpuElemwise
,
GpuDimShuffle
,
GpuCAReduceCuda
,
GpuCAReduceCPY
)
GpuCAReduceCPY
)
from
.subtensor
import
(
GpuIncSubtensor
,
GpuSubtensor
,
from
.subtensor
import
(
GpuIncSubtensor
,
GpuSubtensor
,
GpuAdvancedSubtensor1
,
GpuAdvancedIncSubtensor1
,
GpuAdvancedIncSubtensor1
,
GpuAdvancedIncSubtensor1_dev20
)
GpuAdvancedIncSubtensor1_dev20
)
...
@@ -488,6 +489,12 @@ def local_gpua_incsubtensor(node):
...
@@ -488,6 +489,12 @@ def local_gpua_incsubtensor(node):
node
.
op
.
destroyhandler_tolerate_aliased
)
node
.
op
.
destroyhandler_tolerate_aliased
)
@register_opt
(
'fast_compile'
)
@op_lifter
([
tensor
.
AdvancedSubtensor1
])
def
local_gpua_advanced_subtensor
(
node
):
return
GpuAdvancedSubtensor1
()
@register_opt
(
'fast_compile'
)
@register_opt
(
'fast_compile'
)
@op_lifter
([
tensor
.
AdvancedIncSubtensor1
])
@op_lifter
([
tensor
.
AdvancedIncSubtensor1
])
def
local_gpua_advanced_incsubtensor
(
node
):
def
local_gpua_advanced_incsubtensor
(
node
):
...
@@ -496,7 +503,16 @@ def local_gpua_advanced_incsubtensor(node):
...
@@ -496,7 +503,16 @@ def local_gpua_advanced_incsubtensor(node):
if
pygpu
.
get_default_context
()
.
kind
!=
"cuda"
:
if
pygpu
.
get_default_context
()
.
kind
!=
"cuda"
:
return
None
return
None
x
,
y
=
node
.
inputs
[
0
:
2
]
x
,
y
,
ilist
=
node
.
inputs
# Gpu Ops needs both inputs to have the same dtype
if
(
x
.
type
.
dtype
!=
y
.
type
.
dtype
):
dtype
=
scalar
.
upcast
(
x
.
type
.
dtype
,
y
.
type
.
dtype
)
if
x
.
type
.
dtype
!=
dtype
:
x
=
tensor
.
cast
(
x
,
dtype
)
if
y
.
type
.
dtype
!=
dtype
:
y
=
tensor
.
cast
(
y
,
dtype
)
set_instead_of_inc
=
node
.
op
.
set_instead_of_inc
set_instead_of_inc
=
node
.
op
.
set_instead_of_inc
active_device_no
=
theano
.
sandbox
.
cuda
.
active_device_number
()
active_device_no
=
theano
.
sandbox
.
cuda
.
active_device_number
()
device_properties
=
theano
.
sandbox
.
cuda
.
device_properties
device_properties
=
theano
.
sandbox
.
cuda
.
device_properties
...
@@ -504,11 +520,11 @@ def local_gpua_advanced_incsubtensor(node):
...
@@ -504,11 +520,11 @@ def local_gpua_advanced_incsubtensor(node):
compute_capability
=
device_properties
(
active_device_no
)[
'major'
]
compute_capability
=
device_properties
(
active_device_no
)[
'major'
]
if
(
compute_capability
<
2
or
x
.
ndim
!=
2
or
y
.
ndim
!=
2
):
if
(
compute_capability
<
2
or
x
.
ndim
!=
2
or
y
.
ndim
!=
2
):
return
GpuAdvancedIncSubtensor1
(
return
[
GpuAdvancedIncSubtensor1
(
set_instead_of_inc
=
set_instead_of_inc
)
set_instead_of_inc
=
set_instead_of_inc
)(
x
,
y
,
ilist
)]
else
:
else
:
return
GpuAdvancedIncSubtensor1_dev20
(
return
[
GpuAdvancedIncSubtensor1_dev20
(
set_instead_of_inc
=
set_instead_of_inc
)
set_instead_of_inc
=
set_instead_of_inc
)(
x
,
y
,
ilist
)]
@register_opt
(
'fast_compile'
)
@register_opt
(
'fast_compile'
)
...
...
theano/sandbox/gpuarray/subtensor.py
浏览文件 @
975e0d2b
差异被折叠。
点击展开。
theano/sandbox/gpuarray/tests/test_subtensor.py
浏览文件 @
975e0d2b
...
@@ -7,6 +7,7 @@ from theano.tensor.tests import test_subtensor
...
@@ -7,6 +7,7 @@ from theano.tensor.tests import test_subtensor
from
..basic_ops
import
HostFromGpu
,
GpuFromHost
from
..basic_ops
import
HostFromGpu
,
GpuFromHost
from
..subtensor
import
(
GpuIncSubtensor
,
GpuSubtensor
,
from
..subtensor
import
(
GpuIncSubtensor
,
GpuSubtensor
,
GpuAdvancedSubtensor1
,
GpuAdvancedIncSubtensor1
)
GpuAdvancedIncSubtensor1
)
from
..type
import
gpuarray_shared_constructor
from
..type
import
gpuarray_shared_constructor
...
@@ -24,6 +25,7 @@ class G_subtensor(test_subtensor.T_subtensor):
...
@@ -24,6 +25,7 @@ class G_subtensor(test_subtensor.T_subtensor):
shared
=
gpuarray_shared_constructor
,
shared
=
gpuarray_shared_constructor
,
sub
=
GpuSubtensor
,
sub
=
GpuSubtensor
,
inc_sub
=
GpuIncSubtensor
,
inc_sub
=
GpuIncSubtensor
,
adv_sub1
=
GpuAdvancedSubtensor1
,
adv_incsub1
=
GpuAdvancedIncSubtensor1
,
adv_incsub1
=
GpuAdvancedIncSubtensor1
,
mode
=
mode_with_gpu
,
mode
=
mode_with_gpu
,
# avoid errors with limited devices
# avoid errors with limited devices
...
...
theano/tensor/tests/test_subtensor.py
浏览文件 @
975e0d2b
...
@@ -515,8 +515,8 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
...
@@ -515,8 +515,8 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self
.
assertRaises
(
IndexError
,
g
,
shp
)
self
.
assertRaises
(
IndexError
,
g
,
shp
)
def
test_adv_sub1_broadcast
(
self
):
def
test_adv_sub1_broadcast
(
self
):
ones
=
numpy
.
ones
((
1
,
3
),
dtype
=
self
.
dtype
)
v
=
numpy
.
arange
(
3
,
dtype
=
self
.
dtype
)
.
reshape
((
1
,
3
)
)
n
=
self
.
shared
(
ones
*
5
,
broadcastable
=
(
True
,
False
))
n
=
self
.
shared
(
v
*
5
,
broadcastable
=
(
True
,
False
))
idx
=
tensor
.
lvector
()
idx
=
tensor
.
lvector
()
t
=
n
[
idx
]
t
=
n
[
idx
]
self
.
assertTrue
(
isinstance
(
t
.
owner
.
op
,
tensor
.
AdvancedSubtensor1
))
self
.
assertTrue
(
isinstance
(
t
.
owner
.
op
,
tensor
.
AdvancedSubtensor1
))
...
@@ -529,10 +529,10 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
...
@@ -529,10 +529,10 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self
.
assertTrue
(
isinstance
(
topo_
[
0
]
.
op
,
self
.
adv_sub1
))
self
.
assertTrue
(
isinstance
(
topo_
[
0
]
.
op
,
self
.
adv_sub1
))
f_0
=
f
([
0
])
f_0
=
f
([
0
])
self
.
assertTrue
(
f_0
.
shape
==
(
1
,
3
))
self
.
assertTrue
(
f_0
.
shape
==
(
1
,
3
))
self
.
assertTrue
(
numpy
.
allclose
(
f_0
,
ones
[
0
]
*
5
))
self
.
assertTrue
(
numpy
.
allclose
(
f_0
,
v
*
5
))
f_00
=
f
([
0
,
0
])
f_00
=
f
([
0
,
0
])
self
.
assertTrue
(
f_00
.
shape
==
(
2
,
3
))
self
.
assertTrue
(
f_00
.
shape
==
(
2
,
3
))
self
.
assertTrue
(
numpy
.
allclose
(
f_00
,
5
))
self
.
assertTrue
(
numpy
.
allclose
(
f_00
,
v
*
5
))
self
.
assertRaises
(
IndexError
,
f
,
[
0
,
1
])
self
.
assertRaises
(
IndexError
,
f
,
[
0
,
1
])
# Test the gradient
# Test the gradient
...
...
theano/tests/test_flake8.py
浏览文件 @
975e0d2b
...
@@ -160,7 +160,6 @@ whitelist_flake8 = [
...
@@ -160,7 +160,6 @@ whitelist_flake8 = [
"sandbox/linalg/tests/test_linalg.py"
,
"sandbox/linalg/tests/test_linalg.py"
,
"sandbox/gpuarray/basic_ops.py"
,
"sandbox/gpuarray/basic_ops.py"
,
"sandbox/gpuarray/nnet.py"
,
"sandbox/gpuarray/nnet.py"
,
"sandbox/gpuarray/subtensor.py"
,
"sandbox/gpuarray/elemwise.py"
,
"sandbox/gpuarray/elemwise.py"
,
"sandbox/gpuarray/type.py"
,
"sandbox/gpuarray/type.py"
,
"sandbox/gpuarray/__init__.py"
,
"sandbox/gpuarray/__init__.py"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论