Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
abd0a0fc
提交
abd0a0fc
authored
4月 16, 2015
作者:
Arnaud Bergeron
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add support for float16 to one of the versions of GpuCrossentropy...
上级
cc7c365c
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
41 行增加
和
13 行删除
+41
-13
fp16_help.py
theano/sandbox/gpuarray/fp16_help.py
+18
-0
nnet.py
theano/sandbox/gpuarray/nnet.py
+23
-13
没有找到文件。
theano/sandbox/gpuarray/fp16_help.py
0 → 100644
浏览文件 @
abd0a0fc
def
work_dtype
(
dtype
):
if
dtype
==
'float16'
:
return
'float32'
else
:
return
dtype
def
load_w
(
dtype
):
if
dtype
==
'float16'
:
return
'__half2float'
else
:
return
''
def
write_w
(
dtype
):
if
dtype
==
'float16'
:
return
'__float2half_rn'
else
:
return
''
theano/sandbox/gpuarray/nnet.py
浏览文件 @
abd0a0fc
...
@@ -16,6 +16,7 @@ from .type import GpuArrayType
...
@@ -16,6 +16,7 @@ from .type import GpuArrayType
from
.kernel_codegen
import
(
nvcc_kernel
,
from
.kernel_codegen
import
(
nvcc_kernel
,
inline_softmax
,
inline_softmax
,
inline_softmax_fixed_shared
)
inline_softmax_fixed_shared
)
from
.fp16_help
import
work_dtype
,
load_w
,
write_w
class
GpuCrossentropySoftmaxArgmax1HotWithBias
(
Op
):
class
GpuCrossentropySoftmaxArgmax1HotWithBias
(
Op
):
...
@@ -52,6 +53,12 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
...
@@ -52,6 +53,12 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
dtype_x
=
node
.
inputs
[
0
]
.
dtype
dtype_x
=
node
.
inputs
[
0
]
.
dtype
dtype_b
=
node
.
inputs
[
1
]
.
dtype
dtype_b
=
node
.
inputs
[
1
]
.
dtype
dtype_y_idx
=
node
.
inputs
[
2
]
.
dtype
dtype_y_idx
=
node
.
inputs
[
2
]
.
dtype
work_x
=
work_dtype
(
dtype_x
)
work_b
=
work_dtype
(
dtype_b
)
load_x
=
load_w
(
dtype_x
)
load_b
=
load_w
(
dtype_b
)
write_x
=
write_w
(
dtype_x
)
write_b
=
write_w
(
dtype_b
)
return
"""
return
"""
__global__ void k_xent_sm_1hot_bias_
%(nodename)
s(int M, int N,
__global__ void k_xent_sm_1hot_bias_
%(nodename)
s(int M, int N,
const npy_
%(dtype_x)
s* x_data, int xs0, int xs1,
const npy_
%(dtype_x)
s* x_data, int xs0, int xs1,
...
@@ -67,12 +74,13 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
...
@@ -67,12 +74,13 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
const npy_
%(dtype_y_idx)
s y_idx = y_idx_data[row * y_idxs0];
const npy_
%(dtype_y_idx)
s y_idx = y_idx_data[row * y_idxs0];
npy_
%(dtype_x)
s* sm = sm_data + sms0 * row;
npy_
%(dtype_x)
s* sm = sm_data + sms0 * row;
npy_
%(
dtype
_x)
s sum = 0.0;
npy_
%(
work
_x)
s sum = 0.0;
int row_max_j = 0;
int row_max_j = 0;
npy_
%(
dtype_x)
s row_max = x[0] + b[0]
;
npy_
%(
work_x)
s row_max =
%(load_x)
s(x[0]) +
%(load_b)
s(b[0])
;
for (int j = 1; j < N; ++j)
for (int j = 1; j < N; ++j)
{
{
npy_
%(dtype_x)
s row_ij = x[j*xs1] + b[j*bs0];
npy_
%(work_x)
s row_ij =
%(load_x)
s(x[j*xs1]) +
%(load_b)
s(b[j*bs0]);
//todo: store to shared memory
//todo: store to shared memory
row_max_j = (row_ij > row_max) ? j : row_max_j;
row_max_j = (row_ij > row_max) ? j : row_max_j;
row_max = (row_ij > row_max) ? row_ij : row_max;
row_max = (row_ij > row_max) ? row_ij : row_max;
...
@@ -80,27 +88,30 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
...
@@ -80,27 +88,30 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
//compute the exp
//compute the exp
for (int j = 0; j < N; ++j)
for (int j = 0; j < N; ++j)
{
{
npy_
%(dtype_x)
s row_ij = x[j*xs1] + b[j*bs0];
npy_
%(work_x)
s row_ij =
%(load_x)
s(x[j*xs1]) +
npy_
%(dtype_x)
s sm_ij = exp(row_ij - row_max);
%(load_b)
s(b[j*bs0]);
npy_
%(work_x)
s sm_ij = exp(row_ij - row_max);
sum += sm_ij;
sum += sm_ij;
sm[j * sms1] =
sm_ij
;
sm[j * sms1] =
%(write_x)
s(sm_ij)
;
}
}
npy_
%(
dtype
_x)
s sum_inv = 1.0 / sum;
npy_
%(
work
_x)
s sum_inv = 1.0 / sum;
for (int j = 0; j < N; ++j)
for (int j = 0; j < N; ++j)
{
{
sm[j * sms1] *= sum_inv;
npy_
%(work_x)
s __tmp =
%(load_x)
s(sm[j * sms1]);
__tmp *= sum_inv;
sm[j * sms1] =
%(write_x)
s(__tmp);
}
}
if ((y_idx >= N) || (y_idx < 0))
if ((y_idx >= N) || (y_idx < 0))
{
{
//TODO: set raise an error bit in a global var?
//TODO: set raise an error bit in a global var?
nll_data[row*nlls0] =
0.0
; // raise some suspicion at least...
nll_data[row*nlls0] =
%(write_x)
s(0.0)
; // raise some suspicion at least...
}
}
else
else
{
{
nll_data[row*nlls0] =
- x[y_idx*xs1]
nll_data[row*nlls0] =
%(write_x)
s(-
%(load_x)
s(x[y_idx*xs1])
-
b[y_idx*bs0]
-
%(load_b)
s(b[y_idx*bs0])
+ row_max
+ row_max
+ log(sum);
+ log(sum)
)
;
}
}
am_data[row*ams0] = row_max_j;
am_data[row*ams0] = row_max_j;
}
}
...
@@ -259,7 +270,6 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
...
@@ -259,7 +270,6 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
return
sio
.
getvalue
()
return
sio
.
getvalue
()
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
# return ()
return
(
5
,)
return
(
5
,)
def
c_compiler
(
self
):
def
c_compiler
(
self
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论