Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
25823ba2
提交
25823ba2
authored
2月 13, 2014
作者:
Pierre Luc Carrier
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make GpuSoftmax and GpuSoftmaxWithBias compatible with float64 and adjust unit tests to test this
上级
d58444e1
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
72 行增加
和
26 行删除
+72
-26
kernel_codegen.py
theano/sandbox/gpuarray/kernel_codegen.py
+18
-14
nnet.py
theano/sandbox/gpuarray/nnet.py
+0
-0
test_nnet.py
theano/sandbox/gpuarray/tests/test_nnet.py
+54
-12
没有找到文件。
theano/sandbox/gpuarray/kernel_codegen.py
浏览文件 @
25823ba2
...
...
@@ -124,12 +124,13 @@ def inline_reduce_prod(N, buf, pos, count):
@code_version
((
2
,)
+
inline_reduce_max
.
code_version
+
inline_reduce_sum
.
code_version
)
def
inline_softmax
(
N
,
buf
,
buf2
,
threadPos
,
threadCount
):
def
inline_softmax
(
N
,
buf
,
buf2
,
threadPos
,
threadCount
,
dtype
=
"float32"
):
"""
:param N: length of the buffer
:param threadPos: index of executing thread
:param threadCount: number of executing threads
:param dtype: dtype of the softmax's output
:Precondition: buf and buf2 contain two identical copies of the input
to softmax
...
...
@@ -144,7 +145,7 @@ def inline_softmax(N, buf, buf2, threadPos, threadCount):
#get max of buf (trashing all but buf[0])
inline_reduce_max
(
N
,
buf
,
threadPos
,
threadCount
),
'__syncthreads()'
,
'float row_max = '
+
buf
+
'[0]'
,
(
'npy_
%
s row_max = '
+
buf
+
'[0]'
)
%
dtype
,
'__syncthreads()'
,
'for(int __i='
+
threadPos
+
'; __i<'
+
N
+
'; __i+='
+
threadCount
+
'){'
,
...
...
@@ -154,7 +155,7 @@ def inline_softmax(N, buf, buf2, threadPos, threadCount):
'__syncthreads()'
,
inline_reduce_sum
(
N
,
buf
,
threadPos
,
threadCount
),
'__syncthreads()'
,
'float row_sum = '
+
buf
+
'[0]'
,
(
'npy_
%
s row_sum = '
+
buf
+
'[0]'
)
%
dtype
,
'__syncthreads()'
,
# divide each exp() result by the sum to complete the job.
'for(int __i='
+
threadPos
+
'; __i<'
+
N
+
...
...
@@ -168,15 +169,16 @@ def inline_softmax(N, buf, buf2, threadPos, threadCount):
@code_version
((
1
,))
def
inline_reduce_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
pos
,
count
,
manner_fn
,
manner_init
,
b
=
''
,
stride_b
=
''
):
b
=
''
,
stride_b
=
''
,
dtype
=
'float32'
):
"""Return C++ code for a function that reduces a contiguous buffer.
:param N: length of the buffer
:param buf: buffer pointer of size warpSize * sizeof(
float
)
:param buf: buffer pointer of size warpSize * sizeof(
dtype
)
:param pos: index of executing thread
:param count: number of executing threads
:param b: Optional, pointer to the bias
:param stride_b: Optional, the stride of b if b is provided
:param dtype: Optional, the dtype of the output
:param manner_fn: a function that accepts strings of arguments a
and b, and returns c code for their reduction. (Example:
...
...
@@ -214,7 +216,7 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
{
// This function trashes buf[1..n_threads],
// leaving the reduction result in buf[0].
float
red =
%(init)
s;
npy_
%(dtype)
s
red =
%(init)
s;
#pragma unroll 16
for (int i =
%(pos)
s +
%(count)
s; i<
%(N)
s; i +=
%(count)
s){
red =
%(loop_line)
s;
...
...
@@ -248,11 +250,11 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
@code_version
(
inline_reduce_fixed_shared
.
code_version
)
def
inline_reduce_fixed_shared_max
(
N
,
buf
,
x
,
stride_x
,
pos
,
count
,
b
=
''
,
stride_b
=
''
):
b
=
''
,
stride_b
=
''
,
dtype
=
'float32'
):
return
inline_reduce_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
pos
,
count
,
lambda
a
,
b
:
"max(
%
s,
%
s)"
%
(
a
,
b
),
lambda
a
:
a
,
b
,
stride_b
)
b
,
stride_b
,
dtype
)
@code_version
((
1
,)
+
inline_reduce_max
.
code_version
+
...
...
@@ -260,11 +262,11 @@ def inline_reduce_fixed_shared_max(N, buf, x, stride_x, pos, count,
def
inline_softmax_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
sm
,
sm_stride
,
threadPos
,
threadCount
,
b
=
''
,
stride_b
=
''
):
b
=
''
,
stride_b
=
''
,
dtype
=
"float32"
):
"""
:param N: length of the buffer, atleast waprSize(32).
:param buf: a shared memory buffer of size warpSize * sizeof(
float
)
:param buf: a shared memory buffer of size warpSize * sizeof(
dtype
)
:param x: a ptr to the gpu memory where the row is stored
:param stride_x: the stride between each element in x
:param sm: a ptr to the gpu memory to store the result
...
...
@@ -273,6 +275,7 @@ def inline_softmax_fixed_shared(N, buf, x, stride_x,
:param threadCount: number of executing threads
:param b: Optional, pointer to the bias
:param stride_b: Optional, the stride of b if b is provided
:param dtype: Optional, the dtype of the softmax's output if not float32
:Precondition: buf is empty
:Postcondition: buf[0] contains the softmax,
...
...
@@ -285,16 +288,17 @@ def inline_softmax_fixed_shared(N, buf, x, stride_x,
ret
=
[
#get max of buf (trashing all but buf[0])
inline_reduce_fixed_shared_max
(
N
,
buf
,
x
,
stride_x
,
threadPos
,
threadCount
,
b
,
stride_b
),
threadPos
,
threadCount
,
b
,
stride_b
,
dtype
),
'__syncthreads()'
,
'float row_max = '
+
buf
+
'[0]'
,
(
'npy_
%
s row_max = '
+
buf
+
'[0]'
)
%
dtype
,
'__syncthreads()'
,
inline_reduce_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
threadPos
,
threadCount
,
lambda
a
,
b
:
"
%
s +
%
s"
%
(
a
,
b
),
lambda
a
:
"exp(
%
s - row_max)"
%
a
,
b
,
stride_b
),
b
,
stride_b
,
dtype
),
'__syncthreads()'
,
'float row_sum = '
+
buf
+
'[0]'
,
(
'npy_
%
s row_sum = '
+
buf
+
'[0]'
)
%
dtype
,
'__syncthreads()'
,
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){"
,
]
...
...
theano/sandbox/gpuarray/nnet.py
浏览文件 @
25823ba2
差异被折叠。
点击展开。
theano/sandbox/gpuarray/tests/test_nnet.py
浏览文件 @
25823ba2
...
...
@@ -159,20 +159,44 @@ def test_GpuCrossentropySoftmax1HotWithBiasDx():
rtol
,
atol
)
def
test_softmax_with_bias
():
def
test_softmax_with_bias_float32
():
softmax_with_bias_unittest_template
(
dtypeInput
=
'float32'
,
dtypeBias
=
'float32'
)
def
test_softmax_with_bias_float64
():
softmax_with_bias_unittest_template
(
dtypeInput
=
'float32'
,
dtypeBias
=
'float64'
)
softmax_with_bias_unittest_template
(
dtypeInput
=
'float64'
,
dtypeBias
=
'float32'
)
softmax_with_bias_unittest_template
(
dtypeInput
=
'float64'
,
dtypeBias
=
'float64'
)
def
softmax_with_bias_unittest_template
(
dtypeInput
,
dtypeBias
):
"""
This is basic test for GpuSoftmaxWithBias
This is basic test for GpuSoftmaxWithBias
with float64 variables
We check that we loop when their is too much block
TODO: check that we loop when their is too much thread.(THIS IS
NOT IMPLEMENTED)
"""
x
=
T
.
fmatrix
(
'x'
)
assert
dtypeInput
in
[
'float32'
,
'float64'
]
assert
dtypeBias
in
[
'float32'
,
'float64'
]
if
dtypeInput
==
'float32'
:
x
=
T
.
fmatrix
(
'x'
)
elif
dtypeInput
==
'float64'
:
x
=
T
.
dmatrix
(
'x'
)
# We can't use zeros_like(x[0,::]) as this don't allow to test with
# 0 shape.
z
=
T
.
nnet
.
softmax_with_bias
(
x
,
T
.
arange
(
x
.
shape
[
1
]
*
2
,
dtype
=
'float32'
)[::
2
])
# 0 shape
if
dtypeBias
==
'float32'
:
z
=
T
.
nnet
.
softmax_with_bias
(
x
,
T
.
arange
(
x
.
shape
[
1
]
*
2
,
dtype
=
'float32'
)[::
2
])
elif
dtypeBias
==
'float64'
:
z
=
T
.
nnet
.
softmax_with_bias
(
x
,
T
.
arange
(
x
.
shape
[
1
]
*
2
,
dtype
=
'float64'
)[::
2
])
f
=
theano
.
function
([
x
],
z
,
mode
=
mode_without_gpu
)
f_gpu
=
theano
.
function
([
x
],
z
,
mode
=
mode_with_gpu
)
...
...
@@ -182,7 +206,11 @@ def test_softmax_with_bias():
def
cmp
(
n
,
m
):
#print "test_softmax",n,m
data
=
numpy
.
arange
(
n
*
m
,
dtype
=
'float32'
)
.
reshape
(
n
,
m
)
if
dtypeInput
==
'float32'
:
data
=
numpy
.
arange
(
n
*
m
,
dtype
=
'float32'
)
.
reshape
(
n
,
m
)
elif
dtypeInput
==
'float64'
:
data
=
numpy
.
arange
(
n
*
m
,
dtype
=
'float64'
)
.
reshape
(
n
,
m
)
out
=
f
(
data
)
gout
=
f_gpu
(
data
)
assert
numpy
.
allclose
(
out
,
gout
),
numpy
.
absolute
(
out
-
gout
)
...
...
@@ -205,14 +233,25 @@ def test_softmax_with_bias():
cmp
(
128
,
64
*
1024
)
def
test_softmax
():
def
test_softmax_float32
():
softmax_unittest_template
(
'float32'
)
def
test_softmax_float64
():
softmax_unittest_template
(
'float32'
)
def
softmax_unittest_template
(
dtypeInput
):
"""
This is basic test for GpuSoftmax
This is basic test for GpuSoftmax
with float64 variables
We check that we loop when their is too much block
We use slower code when there isn't enough shared memory
"""
x
=
T
.
fmatrix
(
'x'
)
assert
dtypeInput
in
[
'float32'
,
'float64'
]
if
dtypeInput
==
'float32'
:
x
=
T
.
fmatrix
(
'x'
)
elif
dtypeInput
==
'float64'
:
x
=
T
.
dmatrix
(
'x'
)
z
=
T
.
nnet
.
softmax
(
x
)
f
=
theano
.
function
([
x
],
z
,
mode
=
mode_without_gpu
)
...
...
@@ -222,8 +261,11 @@ def test_softmax():
theano
.
sandbox
.
gpuarray
.
nnet
.
GpuSoftmax
)
def
cmp
(
n
,
m
):
#print "test_softmax",n,m
data
=
numpy
.
arange
(
n
*
m
,
dtype
=
'float32'
)
.
reshape
(
n
,
m
)
if
dtypeInput
==
'float32'
:
data
=
numpy
.
arange
(
n
*
m
,
dtype
=
'float32'
)
.
reshape
(
n
,
m
)
elif
dtypeInput
==
'float64'
:
data
=
numpy
.
arange
(
n
*
m
,
dtype
=
'float64'
)
.
reshape
(
n
,
m
)
out
=
f
(
data
)
gout
=
f_gpu
(
data
)
assert
numpy
.
allclose
(
out
,
gout
),
numpy
.
absolute
(
out
-
gout
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论