Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
91b67a6d
提交
91b67a6d
authored
4月 06, 2010
作者:
Frederic Bastien
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added GpuSoftmaxWithBias to make a real speed up to the tutorial logistic_cg
上级
80fc79ec
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
125 行增加
和
1 行删除
+125
-1
nnet.py
theano/sandbox/cuda/nnet.py
+100
-0
opt.py
theano/sandbox/cuda/opt.py
+13
-1
test_opt.py
theano/sandbox/cuda/tests/test_opt.py
+12
-0
没有找到文件。
theano/sandbox/cuda/nnet.py
浏览文件 @
91b67a6d
...
@@ -368,3 +368,103 @@ class GpuSoftmax (Op):
...
@@ -368,3 +368,103 @@ class GpuSoftmax (Op):
inline_softmax
(
'N'
,
'buf'
,
'buf2'
,
'threadIdx.x'
,
'blockDim.x'
),
inline_softmax
(
'N'
,
'buf'
,
'buf2'
,
'threadIdx.x'
,
'blockDim.x'
),
"sm[blockIdx.x * N + threadIdx.x] = buf[threadIdx.x]"
"sm[blockIdx.x * N + threadIdx.x] = buf[threadIdx.x]"
])
])
class
GpuSoftmaxWithBias
(
Op
):
"""Writeme"""
nin
=
2
nout
=
1
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
__str__
(
self
):
return
self
.
__class__
.
__name__
def
make_node
(
self
,
x
,
b
):
return
Apply
(
self
,
[
x
,
b
],[
x
.
type
()])
def
infer_shape
(
self
,
node
,
shape
):
return
[
shape
[
0
]]
def
c_code_cache_version
(
self
):
#return ()
return
(
1
,)
+
inline_softmax
.
code_version
def
c_code
(
self
,
node
,
nodename
,
(
x
,
b
),
(
z
,),
sub
):
fail
=
sub
[
'fail'
]
return
"""
if (
%(x)
s->nd != 2)
{
PyErr_SetString(PyExc_ValueError, "rank error input");
%(fail)
s;
}
if (
%(b)
s->nd != 1)
{
PyErr_SetString(PyExc_ValueError, "rank error for the bias");
%(fail)
s;
}
if ((CudaNdarray_HOST_DIMS(
%(x)
s)[1] != CudaNdarray_HOST_DIMS(
%(b)
s)[0]))
{
PyErr_Format(PyExc_ValueError, "number of columns in x (
%%
ld) does not match length of b (
%%
ld)",
(long int)CudaNdarray_HOST_DIMS(
%(x)
s)[1], (long int)CudaNdarray_HOST_DIMS(
%(b)
s)[0]);
%(fail)
s;
}
if ((NULL ==
%(z)
s)
|| (CudaNdarray_HOST_DIMS(
%(z)
s)[0] != CudaNdarray_HOST_DIMS(
%(x)
s)[0])
|| (CudaNdarray_HOST_DIMS(
%(z)
s)[1] != CudaNdarray_HOST_DIMS(
%(x)
s)[1]))
{
Py_XDECREF(
%(z)
s);
%(z)
s = (CudaNdarray*)CudaNdarray_new_null();
if ((NULL ==
%(z)
s)
|| CudaNdarray_alloc_contiguous(
%(z)
s, 2, CudaNdarray_HOST_DIMS(
%(x)
s)))
{
Py_XDECREF(
%(z)
s);
%(z)
s = NULL;
%(fail)
s;
}
}
{
kSoftmaxWithBias_
%(nodename)
s
<<<
// todo: cap these at the card limits, implement loops in kernel
CudaNdarray_HOST_DIMS(
%(x)
s)[0],
CudaNdarray_HOST_DIMS(
%(x)
s)[1],
CudaNdarray_HOST_DIMS(
%(x)
s)[1] * 2 * sizeof(float)
>>>(
CudaNdarray_HOST_DIMS(
%(x)
s)[0],
CudaNdarray_HOST_DIMS(
%(x)
s)[1],
CudaNdarray_DEV_DATA(
%(x)
s),
CudaNdarray_HOST_STRIDES(
%(x)
s)[0],
CudaNdarray_HOST_STRIDES(
%(x)
s)[1],
CudaNdarray_DEV_DATA(
%(b)
s),
CudaNdarray_HOST_STRIDES(
%(b)
s)[0],
CudaNdarray_DEV_DATA(
%(z)
s) //guarantee c contig
);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s.
\\
n", "kSoftmax_
%(nodename)
s", cudaGetErrorString(err));
%(fail)
s;
}
}
assert(
%(z)
s);
"""
%
locals
()
def
c_support_code_apply
(
self
,
node
,
nodename
):
return
nvcc_kernel
(
"kSoftmaxWithBias_
%
s"
%
nodename
,
params
=
[
'int M'
,
'int N'
,
'const float * x'
,
'const int sx0'
,
'const int sx1'
,
'const float * b'
,
'const int sb0'
,
'float * sm'
],
body
=
[
"extern __shared__ float buf[]"
,
"float * buf2 = buf + N"
,
"buf[threadIdx.x] = x[blockIdx.x * sx0 + threadIdx.x * sx1]"
,
"buf[threadIdx.x] += b[threadIdx.x * sb0]"
,
"buf2[threadIdx.x] = buf[threadIdx.x]"
,
"__syncthreads()"
,
inline_softmax
(
'N'
,
'buf'
,
'buf2'
,
'threadIdx.x'
,
'blockDim.x'
),
"sm[blockIdx.x * N + threadIdx.x] = buf[threadIdx.x]"
])
theano/sandbox/cuda/opt.py
浏览文件 @
91b67a6d
...
@@ -11,7 +11,7 @@ from theano.sandbox.cuda.blas import GpuDownsampleFactorMax, GpuDownsampleFactor
...
@@ -11,7 +11,7 @@ from theano.sandbox.cuda.blas import GpuDownsampleFactorMax, GpuDownsampleFactor
from
theano.sandbox.cuda.nnet
import
(
from
theano.sandbox.cuda.nnet
import
(
GpuCrossentropySoftmaxArgmax1HotWithBias
,
GpuCrossentropySoftmaxArgmax1HotWithBias
,
GpuCrossentropySoftmax1HotWithBiasDx
,
GpuCrossentropySoftmax1HotWithBiasDx
,
GpuSoftmax
)
GpuSoftmax
,
GpuSoftmaxWithBias
)
from
theano.compile
import
optdb
from
theano.compile
import
optdb
#optdb.print_summary() # this shows what is currently registered (in a so-far crude way...)
#optdb.print_summary() # this shows what is currently registered (in a so-far crude way...)
...
@@ -386,6 +386,18 @@ def local_gpu_softmax(node):
...
@@ -386,6 +386,18 @@ def local_gpu_softmax(node):
return
[
host_from_gpu
(
gpu_sm
)]
return
[
host_from_gpu
(
gpu_sm
)]
return
False
return
False
@register_opt
()
@local_optimizer
([])
def
local_gpu_softmax_with_bias
(
node
):
if
isinstance
(
node
.
op
,
tensor
.
nnet
.
SoftmaxWithBias
):
x
,
b
=
node
.
inputs
x_on_gpu
=
x
.
owner
and
x
.
owner
.
op
==
host_from_gpu
b_on_gpu
=
b
.
owner
and
b
.
owner
.
op
==
host_from_gpu
if
x_on_gpu
or
b_on_gpu
:
gpu_sm
=
GpuSoftmaxWithBias
()(
gpu_from_host
(
x
),
gpu_from_host
(
b
))
return
[
host_from_gpu
(
gpu_sm
)]
return
False
#### Convolution, maxpooling
#### Convolution, maxpooling
from
theano.tensor.nnet
import
conv
from
theano.tensor.nnet
import
conv
@register_opt
()
@register_opt
()
...
...
theano/sandbox/cuda/tests/test_opt.py
浏览文件 @
91b67a6d
...
@@ -16,8 +16,10 @@ from theano.sandbox.cuda.type import CudaNdarrayType
...
@@ -16,8 +16,10 @@ from theano.sandbox.cuda.type import CudaNdarrayType
if
theano
.
config
.
mode
==
'FAST_COMPILE'
:
if
theano
.
config
.
mode
==
'FAST_COMPILE'
:
mode_with_gpu
=
theano
.
compile
.
mode
.
get_mode
(
'FAST_RUN'
)
.
including
(
'gpu'
)
mode_with_gpu
=
theano
.
compile
.
mode
.
get_mode
(
'FAST_RUN'
)
.
including
(
'gpu'
)
mode_without_gpu
=
theano
.
compile
.
mode
.
get_mode
(
'FAST_RUN'
)
.
excluding
(
'gpu'
)
else
:
else
:
mode_with_gpu
=
theano
.
compile
.
mode
.
get_default_mode
()
.
including
(
'gpu'
)
mode_with_gpu
=
theano
.
compile
.
mode
.
get_default_mode
()
.
including
(
'gpu'
)
mode_without_gpu
=
theano
.
compile
.
mode
.
get_default_mode
()
.
excluding
(
'gpu'
)
import
theano.sandbox.cuda
as
cuda
import
theano.sandbox.cuda
as
cuda
...
@@ -49,3 +51,13 @@ def test_int_pow():
...
@@ -49,3 +51,13 @@ def test_int_pow():
#theano.printing.debugprint(f)
#theano.printing.debugprint(f)
def
test_softmax_with_bias
():
x
=
tensor
.
fmatrix
()
b
=
tensor
.
fvector
()
f
=
theano
.
function
([
x
,
b
],
tensor
.
nnet
.
nnet
.
SoftmaxWithBias
()(
x
,
b
),
mode
=
mode_with_gpu
)
f2
=
theano
.
function
([
x
,
b
],
tensor
.
nnet
.
nnet
.
SoftmaxWithBias
()(
x
,
b
),
mode
=
mode_without_gpu
)
assert
isinstance
(
f
.
maker
.
env
.
toposort
()[
2
]
.
op
,
cuda
.
nnet
.
GpuSoftmaxWithBias
)
xv
=
numpy
.
random
.
rand
(
7
,
8
)
bv
=
numpy
.
random
.
rand
(
8
)
assert
numpy
.
allclose
(
f
(
xv
,
bv
),
f2
(
xv
,
bv
))
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论