Commit 024347f9
Authored Nov 29, 2013 by Vincent Dumoulin

Partially fix failing tests

Parent: b922c47c

Showing 3 changed files with 72 additions and 27 deletions:

    theano/sandbox/gpuarray/nnet.py              +34  -21
    theano/sandbox/gpuarray/opt.py               +15   -0
    theano/sandbox/gpuarray/tests/test_nnet.py   +23   -6
theano/sandbox/gpuarray/nnet.py
@@ -12,6 +12,8 @@ try:
 except ImportError:
     pass
 
+from theano.sandbox.gpuarray.basic_ops import as_gpuarray_variable
+
 
 class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
     """
@@ -31,6 +33,9 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
 
     def make_node(self, x, b, y_idx):
         #N.B. won't work when we don't cast y_idx to float anymore
+        x = as_gpuarray_variable(x)
+        b = as_gpuarray_variable(b)
+        y_idx = as_gpuarray_variable(y_idx)
         nll = y_idx.type()
         sm = x.type()
         am = y_idx.type()
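The three `as_gpuarray_variable` calls added above lift whatever the caller passes (host tensors or GPU arrays) into GPU-typed variables before the Apply node is built. A minimal runnable sketch of that coercion pattern; `HostVar`/`GpuVar`/`as_gpu_variable` are hypothetical stand-ins, not Theano's real variable classes:

    # Hypothetical stand-ins for TensorVariable / GpuArrayVariable.
    class HostVar(object):
        def __init__(self, dtype):
            self.dtype = dtype

    class GpuVar(object):
        def __init__(self, dtype):
            self.dtype = dtype

    def as_gpu_variable(x):
        # Idempotent coercion: GPU variables pass through, host variables
        # get wrapped (a real backend would insert a transfer op here).
        if isinstance(x, GpuVar):
            return x
        if isinstance(x, HostVar):
            return GpuVar(x.dtype)
        raise TypeError("cannot lift %r to a GPU variable" % (x,))

    def make_node(x, b, y_idx):
        # Coerce every input first, as the diff does, so the rest of the
        # Op can assume GPU-typed inputs.
        x, b, y_idx = [as_gpu_variable(v) for v in (x, b, y_idx)]
        return [v.dtype for v in (x, b, y_idx)]

    print(make_node(HostVar('float32'), GpuVar('float32'), HostVar('int64')))
    # -> ['float32', 'float32', 'int64']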
@@ -39,29 +44,31 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
     def c_headers(self):
         return ['cuda.h', '<compyte/extension.h>', '<compyte/numpy_compat.h>']
 
-    def c_support_code(self):
-        dtype = self.dtype
+    def c_support_code_apply(self, node):
+        dtype0 = node.inputs[0].dtype
+        dtype1 = node.inputs[1].dtype
+        dtype2 = node.inputs[2].dtype
         return """
         __global__ void k_xent_sm_1hot_bias(int M, int N,
-            const npy_%(dtype)s* x_data, int xs0, int xs1,
-            const npy_%(dtype)s* b, int bs0,
-            const npy_%(dtype)s* y_idx_data, int y_idxs0,
+            const npy_%(dtype0)s* x_data, int xs0, int xs1,
+            const npy_%(dtype1)s* b, int bs0,
+            const npy_%(dtype2)s* y_idx_data, int y_idxs0,
             npy_%(dtype)s* nll_data, int nlls0,
             npy_%(dtype)s* sm_data, int sms0, int sms1,
             npy_%(dtype)s* am_data, int ams0)
         {
           for (int row = blockIdx.x; row < M; row += gridDim.x){
-            const npy_%(dtype)s* x = x_data + xs0 * row;
+            const npy_%(dtype0)s* x = x_data + xs0 * row;
             const int y_idx = (int)y_idx_data[row * y_idxs0];
-            npy_%(dtype)s* sm = sm_data + sms0 * row;
-            npy_%(dtype)s sum = 0.0;
+            npy_%(dtype0)s* sm = sm_data + sms0 * row;
+            npy_%(dtype0)s sum = 0.0;
             int row_max_j = 0;
-            npy_%(dtype)s row_max = x[0] + b[0];
+            npy_%(dtype0)s row_max = x[0] + b[0];
             for (int j = 1; j < N; ++j)
             {
-                npy_%(dtype)s row_ij = x[j*xs1] + b[j*bs0];
+                npy_%(dtype0)s row_ij = x[j*xs1] + b[j*bs0];
                 //todo: store to shared memory
                 row_max_j = (row_ij > row_max) ? j : row_max_j;
                 row_max = (row_ij > row_max) ? row_ij : row_max;
@@ -69,12 +76,12 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
             //compute the exp
             for (int j = 0; j < N; ++j)
             {
-                npy_%(dtype)s row_ij = x[j*xs1] + b[j*bs0];
-                npy_%(dtype)s sm_ij = exp(row_ij - row_max);
+                npy_%(dtype0)s row_ij = x[j*xs1] + b[j*bs0];
+                npy_%(dtype0)s sm_ij = exp(row_ij - row_max);
                 sum += sm_ij;
                 sm[j * sms1] = sm_ij;
             }
-            npy_%(dtype)s sum_inv = 1.0 / sum;
+            npy_%(dtype0)s sum_inv = 1.0 / sum;
             for (int j = 0; j < N; ++j)
             {
                 sm[j * sms1] *= sum_inv;
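The switch from `c_support_code` to `c_support_code_apply` is the heart of these two hunks: the kernel source is now rendered once per Apply node, so each buffer's C element type can follow the dtype of the corresponding input instead of a single `self.dtype`. A self-contained sketch of that per-node templating; `FakeVar`/`FakeNode` are hypothetical stand-ins for Theano's Variable/Apply:

    # Sketch of per-apply kernel specialization, assuming a node-like
    # object whose inputs carry numpy-style dtype strings.
    KERNEL_TEMPLATE = """
    __global__ void k_xent_sm_1hot_bias(int M, int N,
        const npy_%(dtype0)s* x_data, int xs0, int xs1,
        const npy_%(dtype1)s* b, int bs0,
        const npy_%(dtype2)s* y_idx_data, int y_idxs0)
    { /* ... */ }
    """

    class FakeVar(object):
        def __init__(self, dtype):
            self.dtype = dtype

    class FakeNode(object):
        def __init__(self, *inputs):
            self.inputs = list(inputs)

    def c_support_code_apply(node):
        # One dtype per input: the kernel is re-rendered for every Apply
        # node, instead of one self.dtype shared by all uses of the Op.
        dtype0 = node.inputs[0].dtype
        dtype1 = node.inputs[1].dtype
        dtype2 = node.inputs[2].dtype
        return KERNEL_TEMPLATE % locals()

    node = FakeNode(FakeVar('float32'), FakeVar('float32'), FakeVar('int64'))
    print(c_support_code_apply(node))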
@@ -190,7 +197,7 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
         }
         {
             int n_blocks = std::min(PyGpuArray_DIMS(%(x)s)[0],
-                                    NUM_VECTOR_OP_BLOCKS);
+                                    256);
             //TODO: launch more threads per row and do parallel sum and max reductions
             int n_threads = 1;
             int n_shared_bytes = 0; //n_threads * sizeof(%(dtype)s);
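Replacing `NUM_VECTOR_OP_BLOCKS`, apparently a macro from the old cuda backend's headers that is not available in the gpuarray build, with a literal 256 caps the launch grid without losing coverage, because the kernel walks rows with a grid-stride loop. A small sketch of why the clamp is safe:

    # Sketch: why clamping the grid to 256 blocks still covers every row.
    # The kernel iterates `for (row = blockIdx.x; row < M; row += gridDim.x)`,
    # a grid-stride loop; blocks wrap around and process the extra rows.

    def rows_for_block(block_idx, n_rows, n_blocks):
        return list(range(block_idx, n_rows, n_blocks))

    M = 1000                    # rows in the minibatch
    n_blocks = min(M, 256)      # mirrors std::min(PyGpuArray_DIMS(x)[0], 256)

    covered = sorted(r for b in range(n_blocks)
                     for r in rows_for_block(b, M, n_blocks))
    assert covered == list(range(M))   # every row handled exactly once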
@@ -267,6 +274,9 @@ class GpuCrossentropySoftmax1HotWithBiasDx(Op):
         return self.__class__.__name__
 
     def make_node(self, dy, sm, y_idx):
+        dy = as_gpuarray_variable(dy)
+        sm = as_gpuarray_variable(sm)
+        y_idx = as_gpuarray_variable(y_idx)
         return Apply(self, [dy, sm, y_idx], [sm.type()])
 
     def c_code_cache_version(self):
@@ -280,6 +290,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx(Op):
         return NVCC_compiler
 
     def c_code(self, node, nodename, inp, out, sub):
+        typecode = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype)
         dnll, sm, y_idx = inp
         dx, = out
         fail = sub['fail']
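The new `typecode` line resolves the output's numpy dtype to the integer type code that pygpu's C API uses, presumably so the generated C code can allocate `dx` with a matching element type. A sketch of the idea only; the numbers below are made up, not pygpu's actual typecode table:

    # Illustrative only: a made-up dtype -> typecode table standing in
    # for pygpu.gpuarray.dtype_to_typecode.
    HYPOTHETICAL_TYPECODES = {'float32': 11, 'float64': 12, 'int64': 7}

    def dtype_to_typecode(dtype):
        return HYPOTHETICAL_TYPECODES[str(dtype)]

    # The integer gets interpolated into the generated C source (e.g. an
    # allocation call receiving %(typecode)s), so the output buffer's
    # element type matches node.outputs[0].dtype.
    print(dtype_to_typecode('float32'))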
@@ -324,7 +335,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx(Op):
         }
         {
             int n_blocks = std::min(PyGpuArray_DIMS(%(dx)s)[0],
-                                    NUM_VECTOR_OP_BLOCKS);
+                                    256);
             int n_threads = std::min(PyGpuArray_DIMS(%(dx)s)[1],256);
             kCrossEntropySoftmax1HotWithBiasDx_%(nodename)s
@@ -367,18 +378,20 @@ class GpuCrossentropySoftmax1HotWithBiasDx(Op):
         """ % locals()
 
     def c_support_code_apply(self, node, nodename):
-        dtype = self.dtype
+        dtype0 = node.inputs[0].dtype
+        dtype1 = node.inputs[1].dtype
+        dtype2 = node.inputs[2].dtype
         return """
         __global__ void kCrossEntropySoftmax1HotWithBiasDx_%(nodename)s(
             int N, int K,
-            const npy_%(dtype)s* dnll, const int dnll_s0,
-            const npy_%(dtype)s* sm, const int sm_s0, const int sm_s1,
-            const npy_%(dtype)s* y_idx, const int y_idx_s0,
-            npy_%(dtype)s* dx, const int dx_s0, const int dx_s1)
+            const npy_%(dtype0)s* dnll, const int dnll_s0,
+            const npy_%(dtype1)s* sm, const int sm_s0, const int sm_s1,
+            const npy_%(dtype2)s* y_idx, const int y_idx_s0,
+            npy_%(dtype1)s* dx, const int dx_s0, const int dx_s1)
         {
             for (int i = blockIdx.x; i < N; i += gridDim.x)
             {
-                npy_%(dtype)s dnll_i = dnll[i * dnll_s0];
+                npy_%(dtype0)s dnll_i = dnll[i * dnll_s0];
                 int y_i = (int)y_idx[i * y_idx_s0];
                 for (int j = threadIdx.x; j < K; j += blockDim.x)
theano/sandbox/gpuarray/opt.py
@@ -18,6 +18,8 @@ from theano.sandbox.gpuarray.basic_ops import (host_from_gpu,
                                                GpuReshape,
                                                GpuEye)
 from theano.sandbox.gpuarray.blas import gpu_dot22, GpuGemv, GpuGemm
+from theano.sandbox.gpuarray.nnet import (GpuCrossentropySoftmaxArgmax1HotWithBias,
+                                          GpuCrossentropySoftmax1HotWithBiasDx)
 from theano.sandbox.gpuarray.elemwise import (GpuElemwise, _is_scalar,
                                               GpuDimShuffle, GpuCAReduce)
 from theano.sandbox.gpuarray.subtensor import GpuSubtensor
@@ -267,3 +269,16 @@ def local_gpua_dot22(node):
 @op_lifter([tensor.basic.Eye])
 def local_gpua_eye(node):
     return GpuEye(dtype=node.op.dtype)
+
+
+@register_opt()
+@op_lifter([tensor.nnet.CrossentropySoftmaxArgmax1HotWithBias])
+def local_gpua_crossentropysoftmaxargmax1hotwithbias(node):
+    return GpuCrossentropySoftmaxArgmax1HotWithBias()
+
+
+@register_opt()
+@op_lifter([tensor.nnet.CrossentropySoftmax1HotWithBiasDx])
+def local_gpua_crossentropysoftmax1hotwithbiasdx(node):
+    return GpuCrossentropySoftmax1HotWithBiasDx()
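Both new registrations reuse this file's existing pattern: `@register_opt()` enrolls the rewrite in the gpuarray optimizer set, and `@op_lifter([CpuOp])` makes it fire when the listed CPU op is seen, returning the GPU op to substitute. A simplified, runnable sketch of that decorator-driven dispatch, using a plain dict registry; Theano's real machinery also moves inputs to the device and rewrites the graph:

    LIFTERS = {}

    def op_lifter(cpu_op_classes):
        # Register fn as the lifter for each listed CPU op class.
        def decorate(fn):
            for cls in cpu_op_classes:
                LIFTERS[cls] = fn
            return fn
        return decorate

    class CrossentropySoftmax1HotWithBiasDx(object): pass      # stand-in CPU op
    class GpuCrossentropySoftmax1HotWithBiasDx(object): pass   # stand-in GPU op

    @op_lifter([CrossentropySoftmax1HotWithBiasDx])
    def local_gpua_crossentropysoftmax1hotwithbiasdx(node):
        return GpuCrossentropySoftmax1HotWithBiasDx()

    class Node(object):
        def __init__(self, op):
            self.op = op

    node = Node(CrossentropySoftmax1HotWithBiasDx())
    gpu_op = LIFTERS[type(node.op)](node)
    print(type(gpu_op).__name__)   # -> GpuCrossentropySoftmax1HotWithBiasDx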
theano/sandbox/gpuarray/tests/test_nnet.py
@@ -6,10 +6,27 @@ from theano.gof.python25 import any
 import theano.tensor as T
 import theano.tests.unittest_tools as utt
 
-# Skip test if cuda_ndarray is not available.
-import theano.sandbox.cuda as cuda
-if cuda.cuda_available == False:
-    raise SkipTest('Optional package cuda disabled')
+import theano.sandbox.gpuarray
+if theano.sandbox.gpuarray.pygpu is None:
+    raise SkipTest("pygpu not installed")
+
+import theano.sandbox.cuda as cuda_ndarray
+if cuda_ndarray.cuda_available and not theano.sandbox.gpuarray.pygpu_activated:
+    if not cuda_ndarray.use.device_number:
+        #We should not enable all the use like the flag device=gpu,
+        #as many tests don't work in that setup.
+        cuda_ndarray.use('gpu',
+                         default_to_move_computation_to_gpu=False,
+                         move_shared_float32_to_gpu=False,
+                         enable_cuda=False)
+    theano.sandbox.gpuarray.init_dev('cuda')
+
+if not theano.sandbox.gpuarray.pygpu_activated:
+    raise SkipTest("pygpu disabled")
+
+from theano.sandbox.gpuarray.nnet import (GpuCrossentropySoftmaxArgmax1HotWithBias,
+                                          GpuCrossentropySoftmax1HotWithBiasDx)
 
 if theano.config.mode == 'FAST_COMPILE':
     mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
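The setup block changes from "skip unless the old cuda backend is available" to "skip unless pygpu is importable and a gpuarray context can be activated", initializing through the legacy cuda device only when one is already selected, and skipping rather than erroring if activation fails. A generic sketch of that guard pattern; `backend` and its attributes here are hypothetical stand-ins for theano.sandbox.gpuarray:

    import types
    from unittest import SkipTest

    backend = types.SimpleNamespace(pygpu=None, pygpu_activated=False,
                                    init_dev=lambda dev: None)

    def require_backend(backend):
        if backend.pygpu is None:            # library not even importable
            raise SkipTest("pygpu not installed")
        if not backend.pygpu_activated:      # importable but no context yet
            backend.init_dev('cuda')
        if not backend.pygpu_activated:      # init failed: skip, don't error
            raise SkipTest("pygpu disabled")

    try:
        require_backend(backend)
    except SkipTest as e:
        print("would skip:", e)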
@@ -78,7 +95,7 @@ def test_GpuCrossentropySoftmaxArgmax1HotWithBias():
                            T.nnet.CrossentropySoftmaxArgmax1HotWithBias)
                 for node in classify.maker.fgraph.toposort()])
     assert any([isinstance(node.op,
-                           cuda.nnet.GpuCrossentropySoftmaxArgmax1HotWithBias)
+                           theano.sandbox.gpuarray.nnet.GpuCrossentropySoftmaxArgmax1HotWithBias)
                 for node in classify_gpu.maker.fgraph.toposort()])
 
     out = classify(yy, b_values, dot_value)
@@ -133,7 +150,7 @@ def test_GpuCrossentropySoftmax1HotWithBiasDx():
     assert any([isinstance(node.op,
                            T.nnet.CrossentropySoftmax1HotWithBiasDx)
                 for node in cpu_f.maker.fgraph.toposort()])
     assert any([isinstance(node.op,
-                           cuda.nnet.GpuCrossentropySoftmax1HotWithBiasDx)
+                           theano.sandbox.gpuarray.nnet.GpuCrossentropySoftmax1HotWithBiasDx)
                 for node in gpu_f.maker.fgraph.toposort()])
     cpu_out = cpu_f(softmax_output_value)