Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ff0abb5f
提交
ff0abb5f
authored
2月 14, 2014
作者:
abergeron
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1726 from carriepl/master
Conversion of GpuSoftmax and GpuSoftmaxWithBias to the new backend
上级
db4352dc
eaab9d97
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
878 行增加
和
3 行删除
+878
-3
kernel_codegen.py
theano/sandbox/gpuarray/kernel_codegen.py
+320
-0
nnet.py
theano/sandbox/gpuarray/nnet.py
+416
-2
opt.py
theano/sandbox/gpuarray/opt.py
+12
-1
test_nnet.py
theano/sandbox/gpuarray/tests/test_nnet.py
+130
-0
没有找到文件。
theano/sandbox/gpuarray/kernel_codegen.py
0 → 100644
浏览文件 @
ff0abb5f
""" Helper routines for generating gpu kernels for nvcc.
"""
def nvcc_kernel(name, params, body):
    """Return the C code of a kernel function.

    :param name: name given to the generated ``__global__`` function
    :param params: the parameters to the function as one or more strings
    :param body: the [nested] list of statements for the body of the
                 function. These will be separated by ';' characters.
    """
    paramstr = ', '.join(params)
    # Flatten one level of nesting: each item of `body` is either a plain
    # statement string or a list/tuple of statement strings.
    statements = []
    for item in body:
        if isinstance(item, (list, tuple)):
            statements.extend(item)
        else:
            statements.append(item)
    bodystr = ';\n'.join(statements)
    return """__global__ void %(name)s (%(paramstr)s)
    {
        %(bodystr)s;
    }
    """ % locals()
def code_version(version):
    """Decorator tagging a code-generating function with a cache version.

    The tuple is stored on the function as ``f.code_version`` so that
    callers (e.g. ``c_code_cache_version`` implementations) can combine
    the versions of the helpers they rely on.

    :param version: a tuple identifying the generated-code version
    :raises TypeError: if `version` is not a tuple
    """
    if not isinstance(version, tuple):
        raise TypeError('version must be tuple', version)

    def attach(fn):
        fn.code_version = version
        return fn
    return attach


# Sentinel version for helpers whose generated code is not versioned.
UNVERSIONED = ()
@code_version((1,))
def inline_reduce(N, buf, pos, count, manner_fn):
    """Return C++ code for a function that reduces a contiguous buffer.

    :param N: length of the buffer
    :param buf: buffer pointer
    :param pos: index of executing thread
    :param count: number of executing threads
    :param manner_fn: a function that accepts strings of arguments a
        and b, and returns c code for their reduction. (Example:
        return "%(a)s + %(b)s" for a sum reduction).

    :postcondition:
        This function leaves the answer in position 0 of the buffer. The
        rest of the buffer is trashed by this function.

    :note: buf should be in gpu shared memory, we access it many times.
    """
    cur = "%s[%s]" % (buf, pos)

    def step(other):
        # Combine this thread's current slot with `other` via manner_fn.
        return manner_fn(cur, other)

    loop_line = step("%s[i]" % buf)
    r_16 = step("%s[%s+16]" % (buf, pos))
    r_8 = step("%s[%s+8]" % (buf, pos))
    r_4 = step("%s[%s+4]" % (buf, pos))
    r_2 = step("%s[%s+2]" % (buf, pos))
    r_1 = step("%s[%s+1]" % (buf, pos))
    return """
    {
        // This function trashes buf[1..warpSize],
        // leaving the reduction result in buf[0].
        if (%(pos)s < warpSize)
        {
            for (int i = %(pos)s + warpSize; i < %(N)s; i += warpSize)
            {
                %(buf)s[%(pos)s] = %(loop_line)s;
            }
            if (%(pos)s < 16)
            {
                //reduce so that %(pos)s 0 has the sum of everything
                if(%(pos)s + 16 < %(N)s)
                    %(buf)s[%(pos)s] = %(r_16)s;
                if(%(pos)s + 8 < %(N)s)
                    %(buf)s[%(pos)s] = %(r_8)s;
                if(%(pos)s + 4 < %(N)s)
                    %(buf)s[%(pos)s] = %(r_4)s;
                if(%(pos)s + 2 < %(N)s)
                    %(buf)s[%(pos)s] = %(r_2)s;
                if(%(pos)s + 1 < %(N)s)
                    %(buf)s[%(pos)s] = %(r_1)s;
            }
        }
    }
    """ % locals()
@code_version(inline_reduce.code_version)
def inline_reduce_max(N, buf, pos, count):
    """Reduce a shared-memory buffer with the C ``max`` function.

    See :func:`inline_reduce` for the meaning of the parameters.
    """
    def combine(a, b):
        return "max(%s, %s)" % (a, b)
    return inline_reduce(N, buf, pos, count, combine)
@code_version(inline_reduce.code_version)
def inline_reduce_sum(N, buf, pos, count):
    """Reduce a shared-memory buffer with the ``+`` operator.

    See :func:`inline_reduce` for the meaning of the parameters.
    """
    def combine(a, b):
        return "%s + %s" % (a, b)
    return inline_reduce(N, buf, pos, count, combine)
@code_version(inline_reduce.code_version)
def inline_reduce_min(N, buf, pos, count):
    """Reduce a shared-memory buffer with the C ``min`` function.

    See :func:`inline_reduce` for the meaning of the parameters.
    """
    def combine(a, b):
        return "min(%s, %s)" % (a, b)
    return inline_reduce(N, buf, pos, count, combine)
@code_version(inline_reduce.code_version)
def inline_reduce_prod(N, buf, pos, count):
    """Reduce a shared-memory buffer with the ``*`` operator.

    See :func:`inline_reduce` for the meaning of the parameters.
    """
    def combine(a, b):
        return "%s * %s" % (a, b)
    return inline_reduce(N, buf, pos, count, combine)
@code_version((2,) + inline_reduce_max.code_version +
              inline_reduce_sum.code_version)
def inline_softmax(N, buf, buf2, threadPos, threadCount, dtype="float32"):
    """Return the list of C statements of an in-shared-memory softmax.

    :param N: length of the buffer
    :param threadPos: index of executing thread
    :param threadCount: number of executing threads
    :param dtype: dtype of the softmax's output

    :Precondition: buf and buf2 contain two identical copies of the input
        to softmax
    :Postcondition: buf contains the softmax, buf2 contains un-normalized
        softmax

    :note: buf and buf2 should be in gpu shared memory, we access it many
        times
    :note2: We use __i as an int variable in a loop
    """
    # Strided loop header shared by the exp() pass and the divide pass.
    inner_loop = 'for(int __i=%s; __i<%s; __i+=%s){' % (threadPos, N,
                                                        threadCount)
    return [
        # get max of buf (trashing all but buf[0])
        inline_reduce_max(N, buf, threadPos, threadCount),
        '__syncthreads()',
        'npy_%s row_max = %s[0]' % (dtype, buf),
        '__syncthreads()',
        inner_loop,
        '%s[__i] = exp(%s[__i] - row_max)' % (buf, buf2),
        '%s[__i] = %s[__i]' % (buf2, buf),
        '}',
        '__syncthreads()',
        inline_reduce_sum(N, buf, threadPos, threadCount),
        '__syncthreads()',
        'npy_%s row_sum = %s[0]' % (dtype, buf),
        '__syncthreads()',
        # divide each exp() result by the sum to complete the job.
        inner_loop,
        '%s[__i] = %s[__i] / row_sum' % (buf, buf2),
        '}',
        '__syncthreads()',
    ]
@code_version((1,))
def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
                               manner_fn, manner_init,
                               b='', stride_b='', dtype='float32'):
    """Return C++ code for a function that reduces a strided row of `x`.

    :param N: length of the buffer
    :param buf: buffer pointer of size warpSize * sizeof(dtype)
    :param pos: index of executing thread
    :param count: number of executing threads
    :param b: Optional, pointer to the bias
    :param stride_b: Optional, the stride of b if b is provided
    :param dtype: Optional, the dtype of the output
    :param manner_fn: a function that accepts strings of arguments a
        and b, and returns c code for their reduction. (Example:
        return "%(a)s + %(b)s" for a sum reduction).
    :param manner_init: a function that accepts strings of arguments a
        and return c code for its initialization

    :postcondition:
        This function leaves the answer in position 0 of the buffer. The
        rest of the buffer is trashed by this function.

    :note: buf should be in gpu shared memory, we access it many times.
    """
    if b:
        # With a bias, each loaded element is x[...] + b[...].
        init = manner_init("%(x)s[%(pos)s * %(stride_x)s] +"
                           " %(b)s[%(pos)s * %(stride_b)s]" % locals())
        loop_line = manner_fn("red",
                              manner_init("%(x)s[i * %(stride_x)s] + "
                                          "%(b)s[i * %(stride_b)s]" %
                                          locals()))
    else:
        init = manner_init("%(x)s[%(pos)s * %(stride_x)s]" % locals())
        loop_line = manner_fn("red",
                              manner_init("%(x)s[i * %(stride_x)s]" %
                                          locals()))

    cur = "%s[%s]" % (buf, pos)

    def fold(other):
        # Merge `other` into this thread's slot of the shared buffer.
        return manner_fn(cur, other)

    loop_line2 = fold("%s[i]" % buf)
    r_16 = fold("%s[%s+16]" % (buf, pos))
    r_8 = fold("%s[%s+8]" % (buf, pos))
    r_4 = fold("%s[%s+4]" % (buf, pos))
    r_2 = fold("%s[%s+2]" % (buf, pos))
    r_1 = fold("%s[%s+1]" % (buf, pos))
    return """
    {
        // This function trashes buf[1..n_threads],
        // leaving the reduction result in buf[0].
        npy_%(dtype)s red = %(init)s;
        #pragma unroll 16
        for (int i = %(pos)s + %(count)s; i<%(N)s; i += %(count)s){
            red = %(loop_line)s;
        }
        buf[%(pos)s] = red;
        __syncthreads();
        if (%(pos)s < warpSize)
        {
            for (int i = %(pos)s + warpSize; i < %(count)s; i += warpSize)
            {
                %(buf)s[%(pos)s] = %(loop_line2)s;
            }
            if (%(pos)s < 16)
            {
                //reduce so that %(pos)s 0 has the reduction of everything
                if(%(pos)s + 16 < %(N)s)
                    %(buf)s[%(pos)s] = %(r_16)s;
                if(%(pos)s + 8 < %(N)s)
                    %(buf)s[%(pos)s] = %(r_8)s;
                if(%(pos)s + 4 < %(N)s)
                    %(buf)s[%(pos)s] = %(r_4)s;
                if(%(pos)s + 2 < %(N)s)
                    %(buf)s[%(pos)s] = %(r_2)s;
                if(%(pos)s + 1 < %(N)s)
                    %(buf)s[%(pos)s] = %(r_1)s;
            }
        }
    }
    """ % locals()
@code_version(inline_reduce_fixed_shared.code_version)
def inline_reduce_fixed_shared_max(N, buf, x, stride_x, pos, count,
                                   b='', stride_b='', dtype='float32'):
    """Max-reduction of a strided row of `x`, optionally biased by `b`.

    See :func:`inline_reduce_fixed_shared` for the parameters.
    """
    def combine(a, b):
        return "max(%s, %s)" % (a, b)
    return inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
                                      combine, lambda a: a,
                                      b, stride_b, dtype)
# FIX: the cache version previously tracked inline_reduce_max and
# inline_reduce_sum, but this function never inlines those helpers -- it
# inlines the *_fixed_shared reducers below.  Track the helpers actually
# used so bumping their versions invalidates kernels generated from here.
# The resulting tuple is (1, 1, 1) today, same as before, so cached
# compilations stay valid.
@code_version((1,) +
              inline_reduce_fixed_shared_max.code_version +
              inline_reduce_fixed_shared.code_version)
def inline_softmax_fixed_shared(N, buf, x, stride_x,
                                sm, sm_stride,
                                threadPos, threadCount,
                                b='', stride_b='',
                                dtype="float32"):
    """Return the list of C statements of a softmax over a strided row.

    :param N: length of the buffer, at least warpSize (32).
    :param buf: a shared memory buffer of size warpSize * sizeof(dtype)
    :param x: a ptr to the gpu memory where the row is stored
    :param stride_x: the stride between each element in x
    :param sm: a ptr to the gpu memory to store the result
    :param sm_stride: the stride between each sm element
    :param threadPos: index of executing thread
    :param threadCount: number of executing threads
    :param b: Optional, pointer to the bias
    :param stride_b: Optional, the stride of b if b is provided
    :param dtype: Optional, the dtype of the softmax's output if not float32

    :Precondition: buf is empty
    :Postcondition: buf[0] contains the softmax,
        buf2 contains un-normalized softmax

    :note: buf should be in gpu shared memory, we access it many times.
    :note2: We use tx as an int variable in a loop
    """
    ret = [
        # get max of buf (trashing all but buf[0])
        inline_reduce_fixed_shared_max(N, buf, x, stride_x,
                                       threadPos, threadCount,
                                       b, stride_b, dtype),
        '__syncthreads()',
        ('npy_%s row_max = ' + buf + '[0]') % dtype,
        '__syncthreads()',
        # sum of exp(x - row_max), reading x (and b) again from gpu memory
        inline_reduce_fixed_shared(N, buf, x, stride_x,
                                   threadPos, threadCount,
                                   lambda a, b: "%s + %s" % (a, b),
                                   lambda a: "exp(%s - row_max)" % a,
                                   b, stride_b, dtype),
        '__syncthreads()',
        ('npy_%s row_sum = ' + buf + '[0]') % dtype,
        '__syncthreads()',
        "for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
    ]
    # This set all value correctly
    if b:
        ret += ["%(sm)s[tx * %(sm_stride)s] = "
                " exp(%(x)s[tx * %(stride_x)s] +"
                "%(b)s[tx * %(stride_b)s] - row_max)"
                " / row_sum" % locals()]
    else:
        ret += ["%(sm)s[tx * %(sm_stride)s] = "
                "exp(%(x)s[tx * %(stride_x)s] - row_max) / row_sum" %
                locals()]
    ret += [
        "}",
        '__syncthreads()',
    ]
    return ret
theano/sandbox/gpuarray/nnet.py
浏览文件 @
ff0abb5f
import
numpy
import
numpy
from
theano
import
Op
,
Apply
from
theano
import
Op
,
Apply
,
config
from
theano.compat.six
import
StringIO
from
theano.compat.six
import
StringIO
from
theano.sandbox.cuda.nvcc_compiler
import
NVCC_compiler
from
theano.sandbox.cuda.nvcc_compiler
import
NVCC_compiler
...
@@ -14,6 +13,10 @@ except ImportError:
...
@@ -14,6 +13,10 @@ except ImportError:
from
theano.sandbox.gpuarray.basic_ops
import
as_gpuarray_variable
from
theano.sandbox.gpuarray.basic_ops
import
as_gpuarray_variable
from
theano.sandbox.gpuarray.type
import
GpuArrayType
from
theano.sandbox.gpuarray.type
import
GpuArrayType
from
theano.sandbox.gpuarray.kernel_codegen
import
(
nvcc_kernel
,
inline_softmax
,
inline_softmax_fixed_shared
)
class
GpuCrossentropySoftmaxArgmax1HotWithBias
(
Op
):
class
GpuCrossentropySoftmaxArgmax1HotWithBias
(
Op
):
...
@@ -440,3 +443,413 @@ class GpuCrossentropySoftmax1HotWithBiasDx(Op):
...
@@ -440,3 +443,413 @@ class GpuCrossentropySoftmax1HotWithBiasDx(Op):
return
[
'cuda_get_ptr = (CUdeviceptr (*)(gpudata *g))compyte_get_extension("cuda_get_ptr");'
]
return
[
'cuda_get_ptr = (CUdeviceptr (*)(gpudata *g))compyte_get_extension("cuda_get_ptr");'
]
gpu_crossentropy_softmax_1hot_with_bias_dx
=
GpuCrossentropySoftmax1HotWithBiasDx
()
gpu_crossentropy_softmax_1hot_with_bias_dx
=
GpuCrossentropySoftmax1HotWithBiasDx
()
class GpuSoftmax(Op):
    """
    Implement Softmax on the gpu.
    """
    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))

    def __str__(self):
        return self.__class__.__name__

    def make_node(self, x):
        x = as_gpuarray_variable(x)
        return Apply(self, [x], [x.type()])

    def infer_shape(self, node, shape):
        # Softmax normalizes each row in place: output shape == input shape.
        return shape

    def c_code_cache_version(self):
        # Combine this Op's version with the inlined helper's version so
        # changes to the generated kernel invalidate the cache.
        return (12,) + inline_softmax.code_version

    def c_headers(self):
        return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>',
                '<compyte/ext_cuda.h>']

    def c_compiler(self):
        return NVCC_compiler

    def c_init_code(self):
        return ['setup_ext_cuda();']

    def c_code(self, node, nodename, inp, out, sub):
        """Return C code allocating z (if needed) and launching the kernel."""
        dtype_x = node.inputs[0].dtype
        dtype_z = node.outputs[0].dtype
        itemsize_x = numpy.dtype(dtype_x).itemsize
        itemsize_z = numpy.dtype(dtype_z).itemsize
        typecode = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype)
        x, = inp
        z, = out
        fail = sub['fail']
        if config.gpuarray.sync:
            # BUG FIX: this previously read ``dict(zz=zz)``; ``zz`` is not
            # defined in this scope (the output variable is ``z``), so
            # enabling config.gpuarray.sync raised a NameError.
            cnda_thread_sync = "GpuArray_sync(&%(zz)s->ga);" % dict(zz=z)
        else:
            cnda_thread_sync = ""
        return """
        if (PyGpuArray_NDIM(%(x)s) != 2)
        {
            PyErr_SetString(PyExc_ValueError, "rank error");
            %(fail)s;
        }
        if ((NULL == %(z)s) ||
            (PyGpuArray_DIMS(%(z)s)[0] !=
             PyGpuArray_DIMS(%(x)s)[0]) ||
            (PyGpuArray_DIMS(%(z)s)[1] !=
             PyGpuArray_DIMS(%(x)s)[1]))
        {
            Py_XDECREF(%(z)s);
            %(z)s = pygpu_empty(2, PyGpuArray_DIMS(%(x)s),
                                %(typecode)s,
                                GA_C_ORDER,
                                pygpu_default_context(), Py_None);
            if (!%(z)s) {
                %(fail)s
            }
        }
        {
            int n_blocks = std::min(PyGpuArray_DIMS(%(x)s)[0],
                                    (size_t)(32 * 1024));
            //TODO, detect the maximum number of thread per block.
            int n_threads = std::min(PyGpuArray_DIMS(%(x)s)[1], (size_t)512);
            int n_shared_bytes = PyGpuArray_DIMS(%(x)s)[1] *
                                 2 * sizeof(npy_%(dtype_x)s);
            if (PyGpuArray_DIMS(%(x)s)[0] > 0)
            {
                //Those numbers are based on not too recent GPU
                //to make them compatible with more GPU.
                //TODO: read the information from the card.
                if(n_shared_bytes < (32 * 1024 - 500)){
                    kSoftmax_%(nodename)s
                        <<<
                            n_blocks,
                            n_threads,
                            n_shared_bytes
                        >>>(
                            PyGpuArray_DIMS(%(x)s)[0],
                            PyGpuArray_DIMS(%(x)s)[1],
                            (npy_%(dtype_x)s*)(
                                ((char *)cuda_get_ptr(%(x)s->ga.data)) +
                                %(x)s->ga.offset),
                            PyGpuArray_STRIDES(%(x)s)[0] / %(itemsize_x)s,
                            PyGpuArray_STRIDES(%(x)s)[1] / %(itemsize_x)s,
                            (npy_%(dtype_z)s*)(
                                ((char *)cuda_get_ptr(%(z)s->ga.data)) +
                                %(z)s->ga.offset),
                            PyGpuArray_STRIDES(%(z)s)[0] / %(itemsize_z)s,
                            PyGpuArray_STRIDES(%(z)s)[1] / %(itemsize_z)s
                        );
                }else{
                    kSoftmax_fixed_shared%(nodename)s
                        <<<
                            n_blocks,
                            n_threads,
                            n_threads * sizeof(npy_%(dtype_x)s)
                        >>>(
                            PyGpuArray_DIMS(%(x)s)[0],
                            PyGpuArray_DIMS(%(x)s)[1],
                            (npy_%(dtype_x)s*)(
                                ((char *)cuda_get_ptr(%(x)s->ga.data)) +
                                %(x)s->ga.offset),
                            PyGpuArray_STRIDES(%(x)s)[0] / %(itemsize_x)s,
                            PyGpuArray_STRIDES(%(x)s)[1] / %(itemsize_x)s,
                            (npy_%(dtype_z)s*)(
                                ((char *)cuda_get_ptr(%(z)s->ga.data)) +
                                %(z)s->ga.offset),
                            PyGpuArray_STRIDES(%(z)s)[0] / %(itemsize_z)s,
                            PyGpuArray_STRIDES(%(z)s)[1] / %(itemsize_z)s
                        );
                }
                %(cnda_thread_sync)s
                cudaError_t err = cudaGetLastError();
                if( cudaSuccess != err)
                {
                    PyErr_Format(PyExc_RuntimeError,
                                 "Cuda error: %%s: %%s.\\n Used %%d blocks,"
                                 " %%d threads %%d bytes of shared memory",
                                 "kSoftmax[_fixed_shared]%(nodename)s",
                                 cudaGetErrorString(err),
                                 n_blocks, n_threads, n_shared_bytes);
                    %(fail)s;
                }
            }
        }
        assert(%(z)s);
        """ % locals()

    def c_support_code_apply(self, node, nodename):
        """Generate the two CUDA kernels launched by c_code.

        ``kSoftmax_`` stages the whole row (twice) in shared memory;
        ``kSoftmax_fixed_shared`` only needs warpSize shared elements and
        is used when the row is too large for shared memory.
        """
        dtype_x = node.inputs[0].dtype
        dtype_sm = node.outputs[0].dtype
        ret1 = nvcc_kernel(
            "kSoftmax_%s" % nodename,
            params=['int M', 'int N',
                    'const npy_%(dtype_x)s * x',
                    'const int sx0', 'const int sx1',
                    'npy_%(dtype_sm)s * sm',
                    'const int sm_s0', 'const int sm_s1'],
            body=[
                "extern __shared__ npy_%(dtype_sm)s buf[]",
                "npy_%(dtype_sm)s * buf2 = buf + N",
                "for (int blockIDX = blockIdx.x; blockIDX < M;"
                "     blockIDX += gridDim.x){",
                "for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
                "buf[tx] = x[blockIDX * sx0 + tx * sx1]",
                "buf2[tx] = buf[tx]",
                "}",
                "__syncthreads()",
                inline_softmax('N', 'buf', 'buf2',
                               'threadIdx.x', 'blockDim.x', dtype_sm),
                "for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
                # This set all value correctly
                "sm[blockIDX * sm_s0 + tx * sm_s1] = buf[tx]",
                "}",
                "__syncthreads()",
                "}",
            ])
        ret2 = nvcc_kernel(
            "kSoftmax_fixed_shared%s" % nodename,
            params=['int M', 'int N',
                    'const npy_%(dtype_x)s * x',
                    'const int sx0', 'const int sx1',
                    'npy_%(dtype_sm)s * sm',
                    'const int sm_s0', 'const int sm_s1'],
            body=[
                "extern __shared__ npy_%(dtype_sm)s buf[]",
                "for (int blockIDX = blockIdx.x; blockIDX < M;"
                "     blockIDX += gridDim.x){",
                "const npy_%(dtype_x)s *x_ptr = &x[blockIDX * sx0]",
                "npy_%(dtype_sm)s *sm_ptr = &sm[blockIDX * sm_s0]",
                inline_softmax_fixed_shared('N', 'buf', 'x_ptr', 'sx1',
                                            'sm_ptr', 'sm_s1',
                                            'threadIdx.x', 'blockDim.x',
                                            dtype=dtype_sm),
                "__syncthreads()",
                "}",
            ])
        return (ret1 + "\n" + ret2) % locals()


gpu_softmax = GpuSoftmax()
class GpuSoftmaxWithBias(Op):
    """
    Implement SoftmaxWithBias on the gpu.
    """
    # number of inputs (x, b) and outputs (sm)
    nin = 2
    nout = 1

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))

    def __str__(self):
        return self.__class__.__name__

    def make_node(self, x, b):
        x = as_gpuarray_variable(x)
        b = as_gpuarray_variable(b)
        return Apply(self, [x, b], [x.type()])

    def infer_shape(self, node, shape):
        # Output has the shape of x; the bias shape does not matter here.
        return [shape[0]]

    def c_code_cache_version(self):
        # Combine this Op's version with the inlined helper's version so
        # changes to the generated kernel invalidate the cache.
        return (11,) + inline_softmax.code_version

    def c_headers(self):
        return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>',
                '<compyte/ext_cuda.h>']

    def c_compiler(self):
        return NVCC_compiler

    def c_init_code(self):
        return ['setup_ext_cuda();']

    def c_code(self, node, nodename, inp, out, sub):
        """Return C code validating x/b, allocating z, launching the kernel."""
        dtype_x = node.inputs[0].dtype
        dtype_b = node.inputs[1].dtype
        dtype_z = node.outputs[0].dtype
        itemsize_x = numpy.dtype(dtype_x).itemsize
        itemsize_b = numpy.dtype(dtype_b).itemsize
        itemsize_z = numpy.dtype(dtype_z).itemsize
        typecode = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype)
        x, b = inp
        z, = out
        fail = sub['fail']
        if config.gpuarray.sync:
            # BUG FIX: this previously read ``dict(zz=zz)``; ``zz`` is not
            # defined in this scope (the output variable is ``z``), so
            # enabling config.gpuarray.sync raised a NameError.
            cnda_thread_sync = "GpuArray_sync(&%(zz)s->ga);" % dict(zz=z)
        else:
            cnda_thread_sync = ""
        return """
        if (PyGpuArray_NDIM(%(x)s) != 2)
        {
            PyErr_SetString(PyExc_ValueError, "rank error input");
            %(fail)s;
        }
        if (PyGpuArray_NDIM(%(b)s) != 1)
        {
            PyErr_SetString(PyExc_ValueError, "rank error for the bias");
            %(fail)s;
        }
        if ((PyGpuArray_DIMS(%(x)s)[1] !=
             PyGpuArray_DIMS(%(b)s)[0]))
        {
            PyErr_Format(PyExc_ValueError,
                         "number of columns in x (%%ld)"
                         " does not match length of b (%%ld)",
                         (long int)PyGpuArray_DIMS(%(x)s)[1],
                         (long int)PyGpuArray_DIMS(%(b)s)[0]);
            %(fail)s;
        }
        if ((NULL == %(z)s)
            || (PyGpuArray_DIMS(%(z)s)[0] !=
                PyGpuArray_DIMS(%(x)s)[0])
            || (PyGpuArray_DIMS(%(z)s)[1] !=
                PyGpuArray_DIMS(%(x)s)[1]))
        {
            Py_XDECREF(%(z)s);
            %(z)s = pygpu_empty(2, PyGpuArray_DIMS(%(x)s),
                                %(typecode)s,
                                GA_C_ORDER,
                                pygpu_default_context(), Py_None);
            if (!%(z)s) {
                %(fail)s
            }
        }
        {
            int n_blocks = std::min(PyGpuArray_DIMS(%(x)s)[0], (size_t)(32*1024));
            //TODO, detect the maximum number of thread per block.
            int n_threads = std::min(PyGpuArray_DIMS(%(x)s)[1], (size_t)512);
            int n_shared_bytes = PyGpuArray_DIMS(%(x)s)[1] *
                                 2 * sizeof(npy_%(dtype_x)s);
            if (PyGpuArray_DIMS(%(x)s)[0] > 0)
            {
                if(n_shared_bytes < (32 * 1024 - 500)){
                    kSoftmaxWithBias_%(nodename)s
                        <<<
                            n_blocks,
                            n_threads,
                            n_shared_bytes
                        >>>(
                            PyGpuArray_DIMS(%(x)s)[0],
                            PyGpuArray_DIMS(%(x)s)[1],
                            (npy_%(dtype_x)s*)(
                                ((char *)cuda_get_ptr(%(x)s->ga.data)) +
                                %(x)s->ga.offset),
                            PyGpuArray_STRIDES(%(x)s)[0] / %(itemsize_x)s,
                            PyGpuArray_STRIDES(%(x)s)[1] / %(itemsize_x)s,
                            (npy_%(dtype_b)s*)(
                                ((char *)cuda_get_ptr(%(b)s->ga.data)) +
                                %(b)s->ga.offset),
                            PyGpuArray_STRIDES(%(b)s)[0] / %(itemsize_b)s,
                            (npy_%(dtype_z)s*)(
                                ((char *)cuda_get_ptr(%(z)s->ga.data)) +
                                %(z)s->ga.offset),
                            PyGpuArray_STRIDES(%(z)s)[0] / %(itemsize_z)s,
                            PyGpuArray_STRIDES(%(z)s)[1] / %(itemsize_z)s
                        );
                }else{
                    kSoftmaxWithBias_fixed_shared%(nodename)s
                        <<<
                            n_blocks,
                            n_threads,
                            n_threads * sizeof(npy_%(dtype_x)s)
                        >>>(
                            PyGpuArray_DIMS(%(x)s)[0],
                            PyGpuArray_DIMS(%(x)s)[1],
                            (npy_%(dtype_x)s*)(
                                ((char *)cuda_get_ptr(%(x)s->ga.data)) +
                                %(x)s->ga.offset),
                            PyGpuArray_STRIDES(%(x)s)[0] / %(itemsize_x)s,
                            PyGpuArray_STRIDES(%(x)s)[1] / %(itemsize_x)s,
                            (npy_%(dtype_b)s*)(
                                ((char *)cuda_get_ptr(%(b)s->ga.data)) +
                                %(b)s->ga.offset),
                            PyGpuArray_STRIDES(%(b)s)[0] / %(itemsize_b)s,
                            (npy_%(dtype_z)s*)(
                                ((char *)cuda_get_ptr(%(z)s->ga.data)) +
                                %(z)s->ga.offset),
                            PyGpuArray_STRIDES(%(z)s)[0] / %(itemsize_z)s,
                            PyGpuArray_STRIDES(%(z)s)[1] / %(itemsize_z)s
                        );
                }
                %(cnda_thread_sync)s
                cudaError_t err = cudaGetLastError();
                if( cudaSuccess != err)
                {
                    PyErr_Format(PyExc_RuntimeError,
                                 "Cuda error: %%s: %%s.\\n",
                                 "kSoftmaxWithBias_%(nodename)s",
                                 cudaGetErrorString(err));
                    %(fail)s;
                }
            }
        }
        assert(%(z)s);
        """ % locals()

    def c_support_code_apply(self, node, nodename):
        """Generate the two CUDA kernels launched by c_code.

        Same structure as GpuSoftmax's kernels, with the bias added to
        each staged row element.
        """
        dtype_x = node.inputs[0].dtype
        dtype_b = node.inputs[1].dtype
        dtype_sm = node.outputs[0].dtype
        ret1 = nvcc_kernel(
            "kSoftmaxWithBias_%s" % nodename,
            params=['int M', 'int N',
                    'const npy_%(dtype_x)s * x',
                    'const int sx0', 'const int sx1',
                    'const npy_%(dtype_b)s * b', 'const int sb0',
                    'npy_%(dtype_sm)s * sm',
                    'const int sm_s0', 'const int sm_s1'],
            body=[
                "extern __shared__ npy_%(dtype_sm)s buf[]",
                "npy_%(dtype_sm)s * buf2 = buf + N",
                "for (int blockIDX = blockIdx.x; blockIDX < M;"
                "     blockIDX += gridDim.x){",
                "for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
                "buf[tx] = x[blockIDX * sx0 + tx * sx1]",
                "buf[tx] += b[tx * sb0]",
                "buf2[tx] = buf[tx]",
                "}",
                "__syncthreads()",
                inline_softmax('N', 'buf', 'buf2',
                               'threadIdx.x', 'blockDim.x', dtype_sm),
                "for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
                "sm[blockIDX * sm_s0 + tx * sm_s1] = buf[tx]",
                "}",
                "__syncthreads()",
                "}",
            ])
        ret2 = nvcc_kernel(
            "kSoftmaxWithBias_fixed_shared%s" % nodename,
            params=['int M', 'int N',
                    'const npy_%(dtype_x)s * x',
                    'const int sx0', 'const int sx1',
                    'const npy_%(dtype_b)s * b', 'const int sb0',
                    'npy_%(dtype_sm)s * sm',
                    'const int sm_s0', 'const int sm_s1'],
            body=[
                "extern __shared__ npy_%(dtype_sm)s buf[]",
                "for (int blockIDX = blockIdx.x; blockIDX < M;"
                "     blockIDX += gridDim.x){",
                "const npy_%(dtype_x)s *x_ptr = &x[blockIDX * sx0]",
                "npy_%(dtype_sm)s *sm_ptr = &sm[blockIDX * sm_s0]",
                inline_softmax_fixed_shared('N', 'buf', 'x_ptr', 'sx1',
                                            'sm_ptr', 'sm_s1',
                                            'threadIdx.x', 'blockDim.x',
                                            'b', 'sb0',
                                            dtype_sm),
                "__syncthreads()",
                "}",
            ])
        return (ret1 + "\n" + ret2) % locals()


gpu_softmax_with_bias = GpuSoftmaxWithBias()
\ No newline at end of file
theano/sandbox/gpuarray/opt.py
浏览文件 @
ff0abb5f
...
@@ -20,7 +20,9 @@ from theano.sandbox.gpuarray.basic_ops import (host_from_gpu,
...
@@ -20,7 +20,9 @@ from theano.sandbox.gpuarray.basic_ops import (host_from_gpu,
from
theano.sandbox.gpuarray.blas
import
gpu_dot22
,
GpuGemv
,
GpuGemm
from
theano.sandbox.gpuarray.blas
import
gpu_dot22
,
GpuGemv
,
GpuGemm
from
theano.sandbox.gpuarray.conv
import
GpuConv
from
theano.sandbox.gpuarray.conv
import
GpuConv
from
theano.sandbox.gpuarray.nnet
import
(
GpuCrossentropySoftmaxArgmax1HotWithBias
,
from
theano.sandbox.gpuarray.nnet
import
(
GpuCrossentropySoftmaxArgmax1HotWithBias
,
GpuCrossentropySoftmax1HotWithBiasDx
)
GpuCrossentropySoftmax1HotWithBiasDx
,
GpuSoftmaxWithBias
,
GpuSoftmax
)
from
theano.sandbox.gpuarray.elemwise
import
(
GpuElemwise
,
_is_scalar
,
from
theano.sandbox.gpuarray.elemwise
import
(
GpuElemwise
,
_is_scalar
,
GpuDimShuffle
,
GpuCAReduceCuda
)
GpuDimShuffle
,
GpuCAReduceCuda
)
from
theano.sandbox.gpuarray.subtensor
import
GpuIncSubtensor
,
GpuSubtensor
from
theano.sandbox.gpuarray.subtensor
import
GpuIncSubtensor
,
GpuSubtensor
...
@@ -341,6 +343,15 @@ def local_gpua_crossentropysoftmaxargmax1hotwithbias(node):
...
@@ -341,6 +343,15 @@ def local_gpua_crossentropysoftmaxargmax1hotwithbias(node):
def
local_gpua_crossentropysoftmax1hotwithbiasdx
(
node
):
def
local_gpua_crossentropysoftmax1hotwithbiasdx
(
node
):
return
GpuCrossentropySoftmax1HotWithBiasDx
()
return
GpuCrossentropySoftmax1HotWithBiasDx
()
@register_opt()
@op_lifter([tensor.nnet.Softmax])
def local_gpua_softmax(node):
    """Lift a CPU Softmax node to its GPU implementation."""
    return GpuSoftmax()
@register_opt()
@op_lifter([tensor.nnet.SoftmaxWithBias])
def local_gpua_softmaxwithbias(node):
    """Lift a CPU SoftmaxWithBias node to its GPU implementation."""
    return GpuSoftmaxWithBias()
@register_opt
()
@register_opt
()
@op_lifter
([
gpu_from_host
,
ConvOp
])
@op_lifter
([
gpu_from_host
,
ConvOp
])
...
...
theano/sandbox/gpuarray/tests/test_nnet.py
浏览文件 @
ff0abb5f
...
@@ -157,3 +157,132 @@ def test_GpuCrossentropySoftmax1HotWithBiasDx():
...
@@ -157,3 +157,132 @@ def test_GpuCrossentropySoftmax1HotWithBiasDx():
assert
False
,
"numpy.allclose(cpu_out, gpu_out, rtol=
%
s, atol=
%
s)"
%
(
assert
False
,
"numpy.allclose(cpu_out, gpu_out, rtol=
%
s, atol=
%
s)"
%
(
rtol
,
atol
)
rtol
,
atol
)
def test_softmax_with_bias_float32():
    """GpuSoftmaxWithBias with float32 input and float32 bias."""
    softmax_with_bias_unittest_template(dtypeBias='float32',
                                        dtypeInput='float32')
def test_softmax_with_bias_float64():
    """GpuSoftmaxWithBias for every dtype combination involving float64."""
    for dt_in, dt_bias in [('float32', 'float64'),
                           ('float64', 'float32'),
                           ('float64', 'float64')]:
        softmax_with_bias_unittest_template(dtypeInput=dt_in,
                                            dtypeBias=dt_bias)
def softmax_with_bias_unittest_template(dtypeInput, dtypeBias):
    """Basic test of GpuSoftmaxWithBias for the given input/bias dtypes.

    We check that we loop when there is too much block.
    TODO: check that we loop when there is too much thread. (THIS IS
    NOT IMPLEMENTED)

    :param dtypeInput: dtype of the input matrix ('float32' or 'float64')
    :param dtypeBias: dtype of the bias vector ('float32' or 'float64')
    """
    assert dtypeInput in ['float32', 'float64']
    assert dtypeBias in ['float32', 'float64']

    if dtypeInput == 'float32':
        x = T.fmatrix('x')
    elif dtypeInput == 'float64':
        x = T.dmatrix('x')

    # We can't use zeros_like(x[0,::]) as this don't allow to test with
    # 0 shape.  dtypeBias was validated above, so it is passed straight to
    # arange instead of duplicating this call in one branch per dtype.
    z = T.nnet.softmax_with_bias(
        x, T.arange(x.shape[1] * 2, dtype=dtypeBias)[::2])

    f = theano.function([x], z, mode=mode_without_gpu)
    f_gpu = theano.function([x], z, mode=mode_with_gpu)
    assert f.maker.fgraph.toposort()[-1].op == T.nnet.softmax_with_bias
    assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op,
                      theano.sandbox.gpuarray.nnet.GpuSoftmaxWithBias)

    def cmp(n, m):
        # Compare CPU and GPU results on an n x m arange input.
        data = numpy.arange(n * m, dtype=dtypeInput).reshape(n, m)
        out = f(data)
        gout = f_gpu(data)
        assert numpy.allclose(out, gout), numpy.absolute(out - gout)

    cmp(2, 5)
    # we need to test n>32*1024 to check that we make the block loop.
    cmp(2 << 15, 5)
    cmp(4074, 400)
    cmp(0, 10)
    cmp(784, 784)
    cmp(4, 1000)
    cmp(4, 1024)
    cmp(4, 2000)
    cmp(4, 2024)
    # GTX285 don't have enough shared mem for this case.
    cmp(4, 4074)
    # The GTX580, 680 and kepler don't have enough shared memory.
    cmp(2, 10000)
    cmp(128, 16 * 1024)
    cmp(128, 64 * 1024)
def test_softmax_float32():
    """GpuSoftmax with float32 input."""
    softmax_unittest_template(dtypeInput='float32')
def test_softmax_float64():
    """GpuSoftmax with float64 input."""
    softmax_unittest_template(dtypeInput='float64')
def softmax_unittest_template(dtypeInput):
    """Basic test of GpuSoftmax for the given input dtype.

    We check that we loop when there is too much block.
    We use slower code when there isn't enough shared memory.

    :param dtypeInput: dtype of the input matrix ('float32' or 'float64')
    """
    assert dtypeInput in ['float32', 'float64']

    if dtypeInput == 'float32':
        x = T.fmatrix('x')
    elif dtypeInput == 'float64':
        x = T.dmatrix('x')
    z = T.nnet.softmax(x)

    f = theano.function([x], z, mode=mode_without_gpu)
    f_gpu = theano.function([x], z, mode=mode_with_gpu)
    assert f.maker.fgraph.toposort()[-1].op == T.nnet.softmax
    assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op,
                      theano.sandbox.gpuarray.nnet.GpuSoftmax)

    def cmp(n, m):
        # Compare CPU and GPU results on an n x m arange input.
        # dtypeInput was validated above, so it is used directly instead of
        # duplicating this call in one branch per dtype.
        data = numpy.arange(n * m, dtype=dtypeInput).reshape(n, m)
        out = f(data)
        gout = f_gpu(data)
        assert numpy.allclose(out, gout), numpy.absolute(out - gout)

    cmp(2, 5)
    # we need to test n>32*1024 to check that we make the block loop.
    cmp(2 << 15, 5)
    cmp(4074, 400)
    cmp(0, 10)
    cmp(784, 784)
    cmp(4, 1000)
    cmp(4, 1024)
    cmp(4, 2000)
    cmp(4, 2024)
    # The GTX285 don't have enough shared memory.
    cmp(4, 4074)
    # The GTX580, 680 and kepler don't have enough shared memory.
    cmp(2, 10000)
    cmp(128, 16 * 1024)
    cmp(128, 64 * 1024)
\ No newline at end of file
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论