Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5b6f228e
提交
5b6f228e
authored
3月 28, 2013
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make GpuSoftmaxWithBias work with bigger input.
Also in kernel_codegen.py There is a fix then X is strided. The variable loop_line wasn't checking the stride.
上级
bbd6a10e
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
94 行增加
和
43 行删除
+94
-43
kernel_codegen.py
theano/sandbox/cuda/kernel_codegen.py
+38
-13
nnet.py
theano/sandbox/cuda/nnet.py
+43
-4
test_nnet.py
theano/sandbox/cuda/tests/test_nnet.py
+13
-26
没有找到文件。
theano/sandbox/cuda/kernel_codegen.py
浏览文件 @
5b6f228e
...
@@ -162,13 +162,16 @@ def inline_softmax(N, buf, buf2, threadPos, threadCount):
...
@@ -162,13 +162,16 @@ def inline_softmax(N, buf, buf2, threadPos, threadCount):
@code_version
((
1
,))
@code_version
((
1
,))
def
inline_reduce_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
pos
,
count
,
def
inline_reduce_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
pos
,
count
,
manner_fn
,
manner_init
):
manner_fn
,
manner_init
,
b
=
''
,
stride_b
=
''
):
"""Return C++ code for a function that reduces a contiguous buffer.
"""Return C++ code for a function that reduces a contiguous buffer.
:param N: length of the buffer
:param N: length of the buffer
:param buf: buffer pointer of size warpSize * sizeof(float)
:param buf: buffer pointer of size warpSize * sizeof(float)
:param pos: index of executing thread
:param pos: index of executing thread
:param count: number of executing threads
:param count: number of executing threads
:param b: Optional, pointer to the bias
:param stride_b: Optional, the stride of b if b is provided
:param manner_fn: a function that accepts strings of arguments a
:param manner_fn: a function that accepts strings of arguments a
and b, and returns c code for their reduction. (Example:
and b, and returns c code for their reduction. (Example:
...
@@ -183,8 +186,16 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
...
@@ -183,8 +186,16 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
:note: buf should be in gpu shared memory, we access it many times.
:note: buf should be in gpu shared memory, we access it many times.
"""
"""
init
=
manner_init
(
"
%(x)
s[
%(pos)
s *
%(stride_x)
s]"
%
locals
())
if
b
:
loop_line
=
manner_fn
(
"red"
,
manner_init
(
"
%
s[i]"
%
x
))
init
=
manner_init
(
"
%(x)
s[
%(pos)
s *
%(stride_x)
s] +"
"
%(b)
s[
%(pos)
s *
%(stride_b)
s]"
%
locals
())
loop_line
=
manner_fn
(
"red"
,
manner_init
(
"
%(x)
s[i *
%(stride_x)
s] + "
"
%(b)
s[i *
%(stride_b)
s]"
%
locals
()))
else
:
init
=
manner_init
(
"
%(x)
s[
%(pos)
s *
%(stride_x)
s]"
%
locals
())
loop_line
=
manner_fn
(
"red"
,
manner_init
(
"
%(x)
s[i *
%(stride_x)
s]"
%
locals
()))
loop_line2
=
manner_fn
(
"
%
s[
%
s]"
%
(
buf
,
pos
),
loop_line2
=
manner_fn
(
"
%
s[
%
s]"
%
(
buf
,
pos
),
"
%
s[i]"
%
buf
)
"
%
s[i]"
%
buf
)
r_16
=
manner_fn
(
"
%
s[
%
s]"
%
(
buf
,
pos
),
"
%
s[
%
s+16]"
%
(
buf
,
pos
))
r_16
=
manner_fn
(
"
%
s[
%
s]"
%
(
buf
,
pos
),
"
%
s[
%
s+16]"
%
(
buf
,
pos
))
...
@@ -229,17 +240,20 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
...
@@ -229,17 +240,20 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
@code_version
(
inline_reduce_fixed_shared
.
code_version
)
@code_version
(
inline_reduce_fixed_shared
.
code_version
)
def
inline_reduce_fixed_shared_max
(
N
,
buf
,
x
,
stride_x
,
pos
,
count
):
def
inline_reduce_fixed_shared_max
(
N
,
buf
,
x
,
stride_x
,
pos
,
count
,
b
=
''
,
stride_b
=
''
):
return
inline_reduce_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
pos
,
count
,
return
inline_reduce_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
pos
,
count
,
lambda
a
,
b
:
"max(
%
s,
%
s)"
%
(
a
,
b
),
lambda
a
,
b
:
"max(
%
s,
%
s)"
%
(
a
,
b
),
lambda
a
:
a
)
lambda
a
:
a
,
b
,
stride_b
)
@code_version
((
1
,)
+
inline_reduce_max
.
code_version
+
@code_version
((
1
,)
+
inline_reduce_max
.
code_version
+
inline_reduce_sum
.
code_version
)
inline_reduce_sum
.
code_version
)
def
inline_softmax_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
def
inline_softmax_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
sm
,
sm_stride
,
sm
,
sm_stride
,
threadPos
,
threadCount
):
threadPos
,
threadCount
,
b
=
''
,
stride_b
=
''
):
"""
"""
:param N: length of the buffer, atleast waprSize(32).
:param N: length of the buffer, atleast waprSize(32).
...
@@ -250,29 +264,39 @@ def inline_softmax_fixed_shared(N, buf, x, stride_x,
...
@@ -250,29 +264,39 @@ def inline_softmax_fixed_shared(N, buf, x, stride_x,
:param sm_stride: the stride between eash sm element
:param sm_stride: the stride between eash sm element
:param threadPos: index of executing thread
:param threadPos: index of executing thread
:param threadCount: number of executing threads
:param threadCount: number of executing threads
:param b: Optional, pointer to the bias
:param stride_b: Optional, the stride of b if b is provided
:Precondition: buf is empty
:Precondition: buf is empty
:Postcondition: buf[0] contains the softmax, buf2 contains un-normalized softmax
:Postcondition: buf[0] contains the softmax, buf2 contains un-normalized softmax
:note: buf
and buf2
should be in gpu shared memory, we access it many times.
:note: buf should be in gpu shared memory, we access it many times.
:note2: We use
__i
as an int variable in a loop
:note2: We use
tx
as an int variable in a loop
"""
"""
ret
urn
[
ret
=
[
#get max of buf (trashing all but buf[0])
#get max of buf (trashing all but buf[0])
inline_reduce_fixed_shared_max
(
N
,
buf
,
x
,
stride_x
,
threadPos
,
threadCount
),
inline_reduce_fixed_shared_max
(
N
,
buf
,
x
,
stride_x
,
threadPos
,
threadCount
,
b
,
stride_b
),
'__syncthreads()'
,
'__syncthreads()'
,
'float row_max = '
+
buf
+
'[0]'
,
'float row_max = '
+
buf
+
'[0]'
,
'__syncthreads()'
,
'__syncthreads()'
,
inline_reduce_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
threadPos
,
threadCount
,
inline_reduce_fixed_shared
(
N
,
buf
,
x
,
stride_x
,
threadPos
,
threadCount
,
lambda
a
,
b
:
"
%
s +
%
s"
%
(
a
,
b
),
lambda
a
,
b
:
"
%
s +
%
s"
%
(
a
,
b
),
lambda
a
:
"exp(
%
s - row_max)"
%
a
),
lambda
a
:
"exp(
%
s - row_max)"
%
a
,
b
,
stride_b
),
'__syncthreads()'
,
'__syncthreads()'
,
'float row_sum = '
+
buf
+
'[0]'
,
'float row_sum = '
+
buf
+
'[0]'
,
'__syncthreads()'
,
'__syncthreads()'
,
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){"
,
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){"
,
# This set all value correctly
]
"
%(sm)
s[tx *
%(sm_stride)
s] = exp(
%(x)
s[tx *
%(stride_x)
s] - row_max) / row_sum"
%
locals
(),
# This set all value correctly
if
b
:
ret
+=
[
"
%(sm)
s[tx *
%(sm_stride)
s] = exp(
%(x)
s[tx *
%(stride_x)
s] +
%(b)
s[tx *
%(stride_b)
s]- row_max) / row_sum"
%
locals
()]
else
:
ret
+=
[
"
%(sm)
s[tx *
%(sm_stride)
s] = exp(
%(x)
s[tx *
%(stride_x)
s] - row_max) / row_sum"
%
locals
()]
ret
+=
[
"}"
,
"}"
,
'__syncthreads()'
,
'__syncthreads()'
,
]
]
return
ret
\ No newline at end of file
theano/sandbox/cuda/nnet.py
浏览文件 @
5b6f228e
...
@@ -506,7 +506,7 @@ class GpuSoftmaxWithBias (GpuOp):
...
@@ -506,7 +506,7 @@ class GpuSoftmaxWithBias (GpuOp):
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
#return ()
#return ()
return
(
7
,)
+
inline_softmax
.
code_version
return
(
8
,)
+
inline_softmax
.
code_version
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
x
,
b
=
inp
x
,
b
=
inp
...
@@ -550,10 +550,9 @@ class GpuSoftmaxWithBias (GpuOp):
...
@@ -550,10 +550,9 @@ class GpuSoftmaxWithBias (GpuOp):
int n_shared_bytes = CudaNdarray_HOST_DIMS(
%(x)
s)[1] * 2 * sizeof(float);
int n_shared_bytes = CudaNdarray_HOST_DIMS(
%(x)
s)[1] * 2 * sizeof(float);
if (CudaNdarray_HOST_DIMS(
%(x)
s)[0] > 0)
if (CudaNdarray_HOST_DIMS(
%(x)
s)[0] > 0)
{
{
if(n_shared_bytes < (32 * 1024 - 500)){
kSoftmaxWithBias_
%(nodename)
s
kSoftmaxWithBias_
%(nodename)
s
<<<
<<<
// todo: cap these at the card limits,
// implement loops in kernel
n_blocks,
n_blocks,
n_threads,
n_threads,
n_shared_bytes
n_shared_bytes
...
@@ -572,6 +571,28 @@ class GpuSoftmaxWithBias (GpuOp):
...
@@ -572,6 +571,28 @@ class GpuSoftmaxWithBias (GpuOp):
CudaNdarray_HOST_STRIDES(
%(z)
s)[0],
CudaNdarray_HOST_STRIDES(
%(z)
s)[0],
CudaNdarray_HOST_STRIDES(
%(z)
s)[1]
CudaNdarray_HOST_STRIDES(
%(z)
s)[1]
);
);
}else{
kSoftmaxWithBias_fixed_shared
%(nodename)
s
<<<
n_blocks,
n_threads,
n_threads * sizeof(float)
>>>(
CudaNdarray_HOST_DIMS(
%(x)
s)[0],
CudaNdarray_HOST_DIMS(
%(x)
s)[1],
CudaNdarray_DEV_DATA(
%(x)
s),
CudaNdarray_HOST_STRIDES(
%(x)
s)[0],
CudaNdarray_HOST_STRIDES(
%(x)
s)[1],
CudaNdarray_DEV_DATA(
%(b)
s),
CudaNdarray_HOST_STRIDES(
%(b)
s)[0],
CudaNdarray_DEV_DATA(
%(z)
s),
CudaNdarray_HOST_STRIDES(
%(z)
s)[0],
CudaNdarray_HOST_STRIDES(
%(z)
s)[1]
);
}
CNDA_THREAD_SYNC;
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
if( cudaSuccess != err)
...
@@ -588,7 +609,7 @@ class GpuSoftmaxWithBias (GpuOp):
...
@@ -588,7 +609,7 @@ class GpuSoftmaxWithBias (GpuOp):
"""
%
locals
()
"""
%
locals
()
def
c_support_code_apply
(
self
,
node
,
nodename
):
def
c_support_code_apply
(
self
,
node
,
nodename
):
ret
urn
nvcc_kernel
(
"kSoftmaxWithBias_
%
s"
%
nodename
,
ret
1
=
nvcc_kernel
(
"kSoftmaxWithBias_
%
s"
%
nodename
,
params
=
[
'int M'
,
'int N'
,
params
=
[
'int M'
,
'int N'
,
'const float * x'
,
'const int sx0'
,
'const int sx1'
,
'const float * x'
,
'const int sx0'
,
'const int sx1'
,
'const float * b'
,
'const int sb0'
,
'const float * b'
,
'const int sb0'
,
...
@@ -610,5 +631,23 @@ class GpuSoftmaxWithBias (GpuOp):
...
@@ -610,5 +631,23 @@ class GpuSoftmaxWithBias (GpuOp):
"__syncthreads()"
,
"__syncthreads()"
,
"}"
,
"}"
,
])
])
ret2
=
nvcc_kernel
(
"kSoftmaxWithBias_fixed_shared
%
s"
%
nodename
,
params
=
[
'int M'
,
'int N'
,
'const float * x'
,
'const int sx0'
,
'const int sx1'
,
'const float * b'
,
'const int sb0'
,
'float * sm'
,
'const int sm_s0'
,
'const int sm_s1'
],
body
=
[
"extern __shared__ float buf[]"
,
"for (int blockIDX = blockIdx.x; blockIDX < M; blockIDX += gridDim.x){"
,
"const float *x_ptr = &x[blockIDX * sx0]"
,
"float *sm_ptr = &sm[blockIDX * sm_s0]"
,
inline_softmax_fixed_shared
(
'N'
,
'buf'
,
'x_ptr'
,
'sx1'
,
'sm_ptr'
,
'sm_s1'
,
'threadIdx.x'
,
'blockDim.x'
,
'b'
,
'sb0'
),
"__syncthreads()"
,
"}"
,
])
return
ret1
+
"
\n
"
+
ret2
gpu_softmax_with_bias
=
GpuSoftmaxWithBias
()
gpu_softmax_with_bias
=
GpuSoftmaxWithBias
()
theano/sandbox/cuda/tests/test_nnet.py
浏览文件 @
5b6f228e
...
@@ -172,8 +172,8 @@ def test_softmax_with_bias():
...
@@ -172,8 +172,8 @@ def test_softmax_with_bias():
x
=
T
.
fmatrix
(
'x'
)
x
=
T
.
fmatrix
(
'x'
)
# We can't use zeros_like(x[0,::]) as this don't allow to test with
# We can't use zeros_like(x[0,::]) as this don't allow to test with
# 0 shape.
# 0 shape.
z
=
T
.
nnet
.
softmax_with_bias
(
x
,
T
.
a
lloc
(
numpy
.
asarray
(
0
,
dtype
=
'float32'
)
,
z
=
T
.
nnet
.
softmax_with_bias
(
x
,
T
.
a
range
(
x
.
shape
[
1
]
*
2
,
x
.
shape
[
1
])
)
dtype
=
'float32'
)[::
2
]
)
f
=
theano
.
function
([
x
],
z
,
mode
=
mode_without_gpu
)
f
=
theano
.
function
([
x
],
z
,
mode
=
mode_without_gpu
)
f_gpu
=
theano
.
function
([
x
],
z
,
mode
=
mode_with_gpu
)
f_gpu
=
theano
.
function
([
x
],
z
,
mode
=
mode_with_gpu
)
...
@@ -181,24 +181,12 @@ def test_softmax_with_bias():
...
@@ -181,24 +181,12 @@ def test_softmax_with_bias():
assert
isinstance
(
f_gpu
.
maker
.
fgraph
.
toposort
()[
-
2
]
.
op
,
assert
isinstance
(
f_gpu
.
maker
.
fgraph
.
toposort
()[
-
2
]
.
op
,
cuda
.
nnet
.
GpuSoftmaxWithBias
)
cuda
.
nnet
.
GpuSoftmaxWithBias
)
def
cmp
(
n
,
m
,
catch
=
False
):
def
cmp
(
n
,
m
):
"""Some old card won't accet the configuration arguments of
#print "test_softmax",n,m
this implementation. For those cases set catch=True to skip
data
=
numpy
.
arange
(
n
*
m
,
dtype
=
'float32'
)
.
reshape
(
n
,
m
)
those errors.
out
=
f
(
data
)
"""
gout
=
f_gpu
(
data
)
try
:
assert
numpy
.
allclose
(
out
,
gout
),
numpy
.
absolute
(
out
-
gout
)
#print "test_softmax",n,m
data
=
numpy
.
arange
(
n
*
m
,
dtype
=
'float32'
)
.
reshape
(
n
,
m
)
out
=
f
(
data
)
gout
=
f_gpu
(
data
)
assert
numpy
.
allclose
(
out
,
gout
),
numpy
.
absolute
(
out
-
gout
)
except
RuntimeError
,
e
:
if
not
catch
:
raise
# Different CUDA driver have different error message
assert
(
e
.
args
[
0
]
.
startswith
(
'Cuda error: kSoftmaxWithBias_node_0: invalid configuration argument.
\n
'
)
or
e
.
args
[
0
]
.
startswith
(
'Cuda error: kSoftmaxWithBias_node_0: invalid argument.
\n
'
))
cmp
(
2
,
5
)
cmp
(
2
,
5
)
#we need to test n>32*1024 to check that we make the block loop.
#we need to test n>32*1024 to check that we make the block loop.
...
@@ -211,7 +199,11 @@ def test_softmax_with_bias():
...
@@ -211,7 +199,11 @@ def test_softmax_with_bias():
cmp
(
4
,
2000
)
cmp
(
4
,
2000
)
cmp
(
4
,
2024
)
cmp
(
4
,
2024
)
#GTX285 don't have enough shared mem for this case.
#GTX285 don't have enough shared mem for this case.
cmp
(
4
,
4074
,
True
)
cmp
(
4
,
4074
)
# The GTX580, 680 and kepler don't have enough shared memory.
cmp
(
2
,
10000
)
cmp
(
128
,
16
*
1024
)
cmp
(
128
,
64
*
1024
)
def
test_softmax
():
def
test_softmax
():
...
@@ -231,11 +223,6 @@ def test_softmax():
...
@@ -231,11 +223,6 @@ def test_softmax():
cuda
.
nnet
.
GpuSoftmax
)
cuda
.
nnet
.
GpuSoftmax
)
def
cmp
(
n
,
m
):
def
cmp
(
n
,
m
):
"""Some old card won't accept the configuration arguments of
this implementation. For those cases set catch=True to skip
those errors.
"""
#print "test_softmax",n,m
#print "test_softmax",n,m
data
=
numpy
.
arange
(
n
*
m
,
dtype
=
'float32'
)
.
reshape
(
n
,
m
)
data
=
numpy
.
arange
(
n
*
m
,
dtype
=
'float32'
)
.
reshape
(
n
,
m
)
out
=
f
(
data
)
out
=
f
(
data
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论