Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
42869d83
提交
42869d83
authored
7月 19, 2016
作者:
Arnaud Bergeron
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Port the new cudnn batchnorm op.
上级
cb6c5b9c
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
427 行增加
和
0 行删除
+427
-0
dnn.py
theano/gpuarray/dnn.py
+177
-0
dnn_batchnorm.c
theano/gpuarray/dnn_batchnorm.c
+62
-0
dnn_batchnorm_base.c
theano/gpuarray/dnn_batchnorm_base.c
+40
-0
dnn_batchnorm_grad.c
theano/gpuarray/dnn_batchnorm_grad.c
+93
-0
dnn_batchnorm_inf.c
theano/gpuarray/dnn_batchnorm_inf.c
+55
-0
没有找到文件。
theano/gpuarray/dnn.py
浏览文件 @
42869d83
...
@@ -1427,6 +1427,183 @@ class GpuDnnSoftmaxGrad(GpuDnnSoftmaxBase):
...
@@ -1427,6 +1427,183 @@ class GpuDnnSoftmaxGrad(GpuDnnSoftmaxBase):
return
Apply
(
self
,
[
dy
,
sm
],
[
sm
.
type
()])
return
Apply
(
self
,
[
dy
,
sm
],
[
sm
.
type
()])
class
GpuDnnBatchNorm
(
DnnBase
):
"""
Base Op for cuDNN Batch Normalization.
Parameters
----------
mode : {'per-activation', 'spatial'}
Whether to normalize per activation (in this mode, bias and scale
tensor dimensions are 1xCxHxW) or share normalization factors across
spatial dimensions (in this mode, bias and scale tensor dimensions
are 1xCx1x1).
epsilon
Epsilon value used in the batch normalization formula. Minimum allowed
value is 1e-5 (imposed by cuDNN).
"""
__props__
=
(
'mode'
,
'epsilon'
)
def
__init__
(
self
,
mode
=
'per-activation'
,
epsilon
=
1e-4
):
DnnBase
.
__init__
(
self
,
[
'dnn_batchnorm_base.c'
,
'dnn_batchnorm.c'
],
'dnn_batchnorm_op'
)
if
version
()
<
5000
:
raise
RuntimeError
(
"cuDNN Batch Normalization requires cuDNN v5 or later"
)
assert
(
mode
in
(
'per-activation'
,
'spatial'
))
self
.
mode
=
mode
assert
(
epsilon
>=
1e-5
)
self
.
epsilon
=
epsilon
def
get_op_params
(
self
):
params
=
[]
params
.
append
((
'MODE'
,
(
"CUDNN_BATCHNORM_SPATIAL"
if
self
.
mode
==
"spatial"
else
"CUDNN_BATCHNORM_PER_ACTIVATION"
)))
params
.
append
((
'EPSILON'
,
str
(
self
.
epsilon
)))
return
params
def
infer_shape
(
self
,
node
,
shape
):
return
[
shape
[
0
],
shape
[
1
],
shape
[
1
]]
def
make_node
(
self
,
x
,
scale
,
bias
):
ctx_name
=
infer_context_name
(
x
,
scale
,
bias
)
x
=
as_gpuarray_variable
(
x
,
ctx_name
)
scale
=
as_gpuarray_variable
(
scale
,
ctx_name
)
bias
=
as_gpuarray_variable
(
bias
,
ctx_name
)
assert
x
.
ndim
==
4
assert
scale
.
ndim
==
4
assert
bias
.
ndim
==
4
return
Apply
(
self
,
[
x
,
scale
,
bias
],
[
x
.
type
(),
scale
.
type
(),
scale
.
type
()])
def
grad
(
self
,
inputs
,
grads
):
x
,
scale
,
bias
=
inputs
dy
=
grads
[
0
]
_
,
x_mean
,
x_invstd
=
self
.
make_node
(
x
,
scale
,
bias
)
.
outputs
return
GpuDnnBatchNormGrad
(
self
.
mode
,
self
.
epsilon
)(
x
,
dy
,
scale
,
x_mean
,
x_invstd
)
class
GpuDnnBatchNormInference
(
DnnBase
):
"""
Base Op for cuDNN Batch Normalization.
Parameters
----------
mode : {'per-activation', 'spatial'}
Whether to normalize per activation (in this mode, bias and scale
tensor dimensions are 1xCxHxW) or share normalization factors across
spatial dimensions (in this mode, bias and scale tensor dimensions
are 1xCx1x1).
epsilon
Epsilon value used in the batch normalization formula. Minimum allowed
value is 1e-5 (imposed by cuDNN).
"""
__props__
=
(
'mode'
,
'epsilon'
)
def
__init__
(
self
,
mode
=
'per-activation'
,
epsilon
=
1e-4
):
DnnBase
.
__init__
(
self
,
[
'dnn_batchnorm_base.c'
,
'dnn_batchnorm_inf.c'
],
'dnn_batchnorm_op'
)
if
version
()
<
5000
:
raise
RuntimeError
(
"cuDNN Batch Normalization requires cuDNN v5 or later"
)
assert
(
mode
in
(
'per-activation'
,
'spatial'
))
self
.
mode
=
mode
assert
(
epsilon
>=
1e-5
)
self
.
epsilon
=
epsilon
def
get_op_params
(
self
):
params
=
[]
params
.
append
((
'MODE'
,
(
"CUDNN_BATCHNORM_SPATIAL"
if
self
.
mode
==
"spatial"
else
"CUDNN_BATCHNORM_PER_ACTIVATION"
)))
params
.
append
((
'EPSILON'
,
str
(
self
.
epsilon
)))
return
params
def
infer_shape
(
self
,
node
,
shape
):
return
[
shape
[
0
]]
def
make_node
(
self
,
x
,
scale
,
bias
,
estimated_mean
,
estimated_variance
):
ctx_name
=
infer_context_name
(
x
,
scale
,
bias
,
estimated_mean
,
estimated_variance
)
x
=
as_gpuarray_variable
(
x
,
ctx_name
)
scale
=
as_gpuarray_variable
(
scale
,
ctx_name
)
bias
=
as_gpuarray_variable
(
bias
,
ctx_name
)
estimated_mean
=
as_gpuarray_variable
(
estimated_mean
,
ctx_name
)
estimated_variance
=
as_gpuarray_variable
(
estimated_variance
,
ctx_name
)
assert
x
.
ndim
==
4
assert
scale
.
ndim
==
4
assert
bias
.
ndim
==
4
assert
estimated_mean
.
ndim
==
4
assert
estimated_variance
.
ndim
==
4
return
Apply
(
self
,
[
x
,
scale
,
bias
,
estimated_mean
,
estimated_variance
],
[
x
.
type
()])
def
grad
(
self
,
inputs
,
grads
):
x
,
scale
,
bias
,
est_mean
,
est_var
=
inputs
dy
=
grads
[
0
]
if
self
.
mode
==
"per-activation"
:
axes
=
(
0
,)
elif
self
.
mode
==
"spatial"
:
axes
=
(
0
,
2
,
3
)
scale
,
bias
,
est_mean
,
est_var
=
(
theano
.
tensor
.
addbroadcast
(
t
,
*
axes
)
for
t
in
(
scale
,
bias
,
est_mean
,
est_var
))
# define helper expressions
est_var_eps
=
est_var
+
self
.
epsilon
est_std
=
theano
.
tensor
.
sqrt
(
est_var_eps
)
two
=
theano
.
tensor
.
constant
(
2.
)
# define and return gradients
dx
=
dy
*
(
scale
/
est_std
)
dscale
=
(
dy
*
(
x
-
est_mean
))
.
sum
(
axes
,
keepdims
=
True
)
/
est_std
dbias
=
dy
.
sum
(
axes
,
keepdims
=
True
)
dmean
=
-
dy
.
sum
(
axes
,
keepdims
=
True
)
*
(
scale
/
est_std
)
dvar
=
-
(
dy
*
(
x
-
est_mean
))
.
sum
(
axes
,
keepdims
=
True
)
*
(
scale
/
(
two
*
est_var_eps
*
est_std
))
return
[
dx
,
dscale
,
dbias
,
dmean
,
dvar
]
class
GpuDnnBatchNormGrad
(
DnnBase
):
__props__
=
(
'mode'
,
'epsilon'
)
def
__init__
(
self
,
mode
=
'per-activation'
,
epsilon
=
1e-4
):
DnnBase
.
__init__
(
self
,
[
'dnn_batchnorm_base.c'
,
'dnn_batchnorm_grad.c'
],
'dnn_batchnorm_grad'
)
if
version
()
<
5000
:
raise
RuntimeError
(
"cuDNN Batch Normalization requires cuDNN v5 or later"
)
assert
(
mode
in
(
'per-activation'
,
'spatial'
))
self
.
mode
=
mode
assert
(
epsilon
>=
1e-5
)
self
.
epsilon
=
epsilon
def
get_op_params
(
self
):
params
=
[]
params
.
append
((
'MODE'
,
(
"CUDNN_BATCHNORM_SPATIAL"
if
self
.
mode
==
"spatial"
else
"CUDNN_BATCHNORM_PER_ACTIVATION"
)))
params
.
append
((
'EPSILON'
,
str
(
self
.
epsilon
)))
return
params
def
make_node
(
self
,
x
,
dy
,
scale
,
x_mean
,
x_invstd
):
ctx_name
=
infer_context_name
(
x
,
dy
,
scale
,
x_mean
,
x_invstd
)
x
=
as_gpuarray_variable
(
x
,
ctx_name
)
dy
=
as_gpuarray_variable
(
dy
,
ctx_name
)
scale
=
as_gpuarray_variable
(
scale
,
ctx_name
)
x_mean
=
as_gpuarray_variable
(
x_mean
,
ctx_name
)
x_invstd
=
as_gpuarray_variable
(
x_invstd
,
ctx_name
)
assert
x
.
ndim
==
4
and
dy
.
ndim
==
4
and
scale
.
ndim
==
4
and
x_mean
.
ndim
==
4
and
x_invstd
.
ndim
==
4
return
Apply
(
self
,
[
x
,
dy
,
scale
,
x_mean
,
x_invstd
],
[
x
.
type
(),
scale
.
type
(),
scale
.
type
()])
def
infer_shape
(
self
,
node
,
shape
):
return
[
shape
[
0
],
shape
[
2
],
shape
[
2
]]
@register_opt2
([
AbstractConv2d
,
AbstractConv2d_gradWeights
,
@register_opt2
([
AbstractConv2d
,
AbstractConv2d_gradWeights
,
AbstractConv2d_gradInputs
],
'fast_compile'
,
'conv_dnn'
,
'cudnn'
)
AbstractConv2d_gradInputs
],
'fast_compile'
,
'conv_dnn'
,
'cudnn'
)
def
local_abstractconv_cudnn_graph
(
op
,
context_name
,
inputs
,
outputs
):
def
local_abstractconv_cudnn_graph
(
op
,
context_name
,
inputs
,
outputs
):
...
...
theano/gpuarray/dnn_batchnorm.c
0 → 100644
浏览文件 @
42869d83
#section support_code_struct
int
dnn_batchnorm_op
(
PyGpuArrayObject
*
inp
,
PyGpuArrayObject
*
scale
,
PyGpuArrayObject
*
bias
,
PyGpuArrayObject
**
outp
,
PyGpuArrayObject
**
x_mean
,
PyGpuArrayObject
**
x_invstd
,
PyGpuContextObject
*
c
)
{
if
(
c_set_tensorNd
(
inp
,
bn_input
)
!=
0
)
return
1
;
if
(
c_set_tensorNd
(
scale
,
bn_params
)
!=
0
)
return
1
;
if
(
theano_prep_output
(
outp
,
inp
->
ga
.
nd
,
inp
->
ga
.
dimensions
,
inp
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
theano_prep_output
(
x_mean
,
scale
->
ga
.
nd
,
scale
->
ga
.
dimensions
,
scale
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
theano_prep_output
(
x_invstd
,
scale
->
ga
.
nd
,
scale
->
ga
.
dimensions
,
scale
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
c_set_tensorNd
(
*
outp
,
bn_output
)
!=
0
)
return
1
;
{
const
float
falpha
=
1
.;
const
float
fbeta
=
0
.;
const
double
dalpha
=
1
.;
const
double
dbeta
=
0
.;
void
*
alpha
;
void
*
beta
;
if
(
inp
->
ga
.
typecode
==
GA_DOUBLE
)
{
alpha
=
(
void
*
)
&
dalpha
;
beta
=
(
void
*
)
&
dbeta
;
}
else
{
alpha
=
(
void
*
)
&
falpha
;
beta
=
(
void
*
)
&
fbeta
;
}
cudnnStatus_t
err
=
cudnnBatchNormalizationForwardTraining
(
APPLY_SPECIFIC
(
_handle
),
MODE
,
alpha
,
beta
,
bn_input
,
PyGpuArray_DEV_DATA
(
inp
),
bn_output
,
PyGpuArray_DEV_DATA
(
*
outp
),
bn_params
,
PyGpuArray_DEV_DATA
(
scale
),
PyGpuArray_DEV_DATA
(
bias
),
0
,
NULL
,
// running mean, deliberately unused
NULL
,
// running var, deliberately unused
EPSILON
,
PyGpuArray_DEV_DATA
(
*
x_mean
),
PyGpuArray_DEV_DATA
(
*
x_invstd
)
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error during batchnorm: %s
\n
"
,
cudnnGetErrorString
(
err
));
return
1
;
}
}
return
0
;
}
theano/gpuarray/dnn_batchnorm_base.c
0 → 100644
浏览文件 @
42869d83
#section init_code_struct
{
cudnnStatus_t
err
;
bn_input
=
NULL
;
bn_params
=
NULL
;
bn_output
=
NULL
;
if
((
err
=
cudnnCreateTensorDescriptor
(
&
bn_input
))
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_MemoryError
,
"could not allocate tensor descriptor "
"(bn_input): %s"
,
cudnnGetErrorString
(
err
));
FAIL
;
}
if
((
err
=
cudnnCreateTensorDescriptor
(
&
bn_params
))
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_MemoryError
,
"could not allocate tensor descriptor "
"(bn_params): %s"
,
cudnnGetErrorString
(
err
));
FAIL
;
}
if
((
err
=
cudnnCreateTensorDescriptor
(
&
bn_output
))
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_MemoryError
,
"could not allocate tensor descriptor "
"(bn_output): %s"
,
cudnnGetErrorString
(
err
));
FAIL
;
}
}
#section cleanup_code_struct
if
(
bn_input
!=
NULL
)
cudnnDestroyTensorDescriptor
(
bn_input
);
if
(
bn_params
!=
NULL
)
cudnnDestroyTensorDescriptor
(
bn_params
);
if
(
bn_output
!=
NULL
)
cudnnDestroyTensorDescriptor
(
bn_output
);
#section support_code_struct
cudnnTensorDescriptor_t
bn_input
;
cudnnTensorDescriptor_t
bn_params
;
cudnnTensorDescriptor_t
bn_output
;
theano/gpuarray/dnn_batchnorm_grad.c
0 → 100644
浏览文件 @
42869d83
#section init_code_struct
{
cudnnStatus_t
err
;
bn_doutput
=
NULL
;
if
((
err
=
cudnnCreateTensorDescriptor
(
&
bn_doutput
))
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_MemoryError
,
"could not allocate tensor descriptor "
"(bn_doutput): %s"
,
cudnnGetErrorString
(
err
));
FAIL
;
}
}
#section cleanup_code_struct
if
(
bn_doutput
!=
NULL
)
cudnnDestroyTensorDescriptor
(
bn_doutput
);
#section support_code_struct
cudnnTensorDescriptor_t
bn_doutput
;
int
dnn_batchnorm_grad
(
PyGpuArrayObject
*
inp
,
PyGpuArrayObject
*
doutp
,
PyGpuArrayObject
*
scale
,
PyGpuArrayObject
*
x_mean
,
PyGpuArrayObject
*
x_invstd
,
PyGpuArrayObject
**
dinp
,
PyGpuArrayObject
**
dscale
,
PyGpuArrayObject
**
dbias
,
PyGpuContextObject
*
c
)
{
if
(
c_set_tensorNd
(
inp
,
bn_input
)
!=
0
)
return
1
;
if
(
c_set_tensorNd
(
doutp
,
bn_doutput
)
!=
0
)
return
1
;
if
(
c_set_tensorNd
(
scale
,
bn_params
)
!=
0
)
return
1
;
if
(
theano_prep_output
(
dinp
,
inp
->
ga
.
nd
,
inp
->
ga
.
dimensions
,
inp
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
theano_prep_output
(
dscale
,
scale
->
ga
.
nd
,
scale
->
ga
.
dimensions
,
scale
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
theano_prep_output
(
dbias
,
scale
->
ga
.
nd
,
scale
->
ga
.
dimensions
,
scale
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
c_set_tensorNd
(
*
dinp
,
bn_output
)
!=
0
)
return
1
;
{
const
float
falpha
=
1
.;
const
float
fbeta
=
0
.;
const
double
dalpha
=
1
.;
const
double
dbeta
=
0
.;
void
*
alphaData
;
void
*
betaData
;
void
*
alphaParam
;
void
*
betaParam
;
if
(
inp
->
ga
.
typecode
==
GA_DOUBLE
)
{
alphaData
=
(
void
*
)
&
dalpha
;
betaData
=
(
void
*
)
&
dbeta
;
alphaParam
=
(
void
*
)
&
dalpha
;
betaParam
=
(
void
*
)
&
dbeta
;
}
else
{
alphaData
=
(
void
*
)
&
falpha
;
betaData
=
(
void
*
)
&
fbeta
;
alphaParam
=
(
void
*
)
&
falpha
;
betaParam
=
(
void
*
)
&
fbeta
;
}
cudnnStatus_t
err
=
cudnnBatchNormalizationBackward
(
APPLY_SPECIFIC
(
_handle
),
MODE
,
alphaData
,
betaData
,
alphaParam
,
betaParam
,
bn_input
,
PyGpuArray_DEV_DATA
(
inp
),
bn_doutput
,
PyGpuArray_DEV_DATA
(
doutp
),
bn_output
,
PyGpuArray_DEV_DATA
(
*
dinp
),
bn_params
,
PyGpuArray_DEV_DATA
(
scale
),
PyGpuArray_DEV_DATA
(
*
dscale
),
PyGpuArray_DEV_DATA
(
*
dbias
),
EPSILON
,
PyGpuArray_DEV_DATA
(
x_mean
),
PyGpuArray_DEV_DATA
(
x_invstd
)
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error during batchnorm: %s
\n
"
,
cudnnGetErrorString
(
err
));
return
1
;
}
}
return
0
;
}
theano/gpuarray/dnn_batchnorm_inf.c
0 → 100644
浏览文件 @
42869d83
#section support_code_struct
int
dnn_batchnorm_op
(
PyGpuArrayObject
*
inp
,
PyGpuArrayObject
*
scale
,
PyGpuArrayObject
*
bias
,
PyGpuArrayObject
*
est_mean
,
PyGpuArrayObject
*
est_var
,
PyGpuArrayObject
**
outp
,
PyGpuContextObject
*
c
)
{
if
(
c_set_tensorNd
(
inp
,
bn_input
)
!=
0
)
return
1
;
if
(
c_set_tensorNd
(
scale
,
bn_params
)
!=
0
)
return
1
;
if
(
theano_prep_output
(
outp
,
inp
->
ga
.
nd
,
inp
->
ga
.
dimensions
,
inp
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
c_set_tensorNd
(
*
outp
,
bn_output
)
!=
0
)
return
1
;
{
const
float
falpha
=
1
.;
const
float
fbeta
=
0
.;
const
double
dalpha
=
1
.;
const
double
dbeta
=
0
.;
void
*
alpha
;
void
*
beta
;
if
(
inp
->
ga
.
typecode
==
GA_DOUBLE
)
{
alpha
=
(
void
*
)
&
dalpha
;
beta
=
(
void
*
)
&
dbeta
;
}
else
{
alpha
=
(
void
*
)
&
falpha
;
beta
=
(
void
*
)
&
fbeta
;
}
cudnnStatus_t
err
=
cudnnBatchNormalizationForwardInference
(
APPLY_SPECIFIC
(
_handle
),
MODE
,
alpha
,
beta
,
bn_input
,
PyGpuArray_DEV_DATA
(
inp
),
bn_output
,
PyGpuArray_DEV_DATA
(
*
outp
),
bn_params
,
PyGpuArray_DEV_DATA
(
scale
),
PyGpuArray_DEV_DATA
(
bias
),
PyGpuArray_DEV_DATA
(
est_mean
),
PyGpuArray_DEV_DATA
(
est_var
),
EPSILON
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error during batchnorm: %s
\n
"
,
cudnnGetErrorString
(
err
));
return
1
;
}
}
return
0
;
}
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论