Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
fd0e3b65
提交
fd0e3b65
authored
11月 10, 2016
作者:
Gijs van Tulder
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Inplace running mean and variance on gpuarray.
上级
186056b8
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
94 行增加
和
4 行删除
+94
-4
dnn.py
theano/gpuarray/dnn.py
+35
-3
dnn_batchnorm.c
theano/gpuarray/dnn_batchnorm.c
+13
-1
test_dnn.py
theano/gpuarray/tests/test_dnn.py
+46
-0
没有找到文件。
theano/gpuarray/dnn.py
浏览文件 @
fd0e3b65
...
...
@@ -41,7 +41,7 @@ from .elemwise import GpuElemwise
# GpuDownsampleFactorMax, GpuDownsampleFactorMaxGrad
from
.nnet
import
GpuSoftmax
from
.opt
import
(
gpu_seqopt
,
register_opt
,
pool_db
,
pool_db2
,
op_lifter
,
register_opt2
)
op_lifter
,
register_opt2
,
register_inplace
)
from
.opt_util
import
alpha_merge
,
output_merge
,
inplace_allocempty
,
pad_dims
,
unpad_dims
...
...
@@ -1666,20 +1666,32 @@ class GpuDnnBatchNorm(DnnBase):
both be None.
"""
__props__
=
(
'mode'
,
'running_averages'
)
__props__
=
(
'mode'
,
'running_averages'
,
'inplace_running_mean'
,
'inplace_running_var'
)
def
__init__
(
self
,
mode
=
'per-activation'
,
running_averages
=
False
):
def
__init__
(
self
,
mode
=
'per-activation'
,
running_averages
=
False
,
inplace_running_mean
=
False
,
inplace_running_var
=
False
):
DnnBase
.
__init__
(
self
,
[
'dnn_batchnorm_base.c'
,
'dnn_batchnorm.c'
],
'dnn_batchnorm_op'
)
assert
(
mode
in
(
'per-activation'
,
'spatial'
))
self
.
mode
=
mode
self
.
running_averages
=
running_averages
self
.
inplace_running_mean
=
inplace_running_mean
self
.
inplace_running_var
=
inplace_running_var
self
.
destroy_map
=
{}
if
self
.
running_averages
and
self
.
inplace_running_mean
:
self
.
destroy_map
[
3
]
=
[
5
]
if
self
.
running_averages
and
self
.
inplace_running_var
:
self
.
destroy_map
[
4
]
=
[
6
]
def
get_op_params
(
self
):
params
=
[]
if
self
.
running_averages
:
params
.
append
((
'RUNNING_AVERAGES'
,
'1'
))
if
self
.
inplace_running_mean
:
params
.
append
((
'INPLACE_RUNNING_MEAN'
,
'1'
))
if
self
.
inplace_running_var
:
params
.
append
((
'INPLACE_RUNNING_VAR'
,
'1'
))
params
.
append
((
'MODE'
,
(
"CUDNN_BATCHNORM_SPATIAL"
if
self
.
mode
==
"spatial"
else
"CUDNN_BATCHNORM_PER_ACTIVATION"
)))
...
...
@@ -3115,6 +3127,26 @@ def local_abstract_batch_norm_train_cudnn(node):
return
results
@register_inplace
()
@local_optimizer
([
GpuDnnBatchNorm
],
inplace
=
True
)
def
local_batch_norm_inplace_running_mean
(
node
):
if
isinstance
(
node
.
op
,
GpuDnnBatchNorm
)
and
node
.
op
.
running_averages
and
not
node
.
op
.
inplace_running_mean
:
return
GpuDnnBatchNorm
(
mode
=
node
.
op
.
mode
,
running_averages
=
node
.
op
.
running_averages
,
inplace_running_mean
=
True
,
inplace_running_var
=
node
.
op
.
inplace_running_var
)(
*
node
.
inputs
)
@register_inplace
()
@local_optimizer
([
GpuDnnBatchNorm
],
inplace
=
True
)
def
local_batch_norm_inplace_running_var
(
node
):
if
isinstance
(
node
.
op
,
GpuDnnBatchNorm
)
and
node
.
op
.
running_averages
and
not
node
.
op
.
inplace_running_var
:
return
GpuDnnBatchNorm
(
mode
=
node
.
op
.
mode
,
running_averages
=
node
.
op
.
running_averages
,
inplace_running_mean
=
node
.
op
.
inplace_running_mean
,
inplace_running_var
=
True
)(
*
node
.
inputs
)
@local_optimizer
([
bn
.
AbstractBatchNormTrainGrad
])
def
local_abstract_batch_norm_train_grad_cudnn
(
node
):
if
not
isinstance
(
node
.
op
,
bn
.
AbstractBatchNormTrainGrad
):
...
...
theano/gpuarray/dnn_batchnorm.c
浏览文件 @
fd0e3b65
...
...
@@ -36,16 +36,28 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale,
return
1
;
#ifdef RUNNING_AVERAGES
#ifdef INPLACE_RUNNING_MEAN
Py_XDECREF
(
out_running_mean
);
PyGpuArrayObject
*
running_mean
=
in_running_mean
;
Py_INCREF
(
running_mean
);
#else
PyGpuArrayObject
*
running_mean
=
*
out_running_mean
;
PyGpuArrayObject
*
running_var
=
*
out_running_var
;
running_mean
=
theano_try_copy
(
running_mean
,
in_running_mean
);
if
(
running_mean
==
NULL
)
{
return
1
;
}
#endif
#ifdef INPLACE_RUNNING_VAR
Py_XDECREF
(
out_running_var
);
PyGpuArrayObject
*
running_var
=
in_running_var
;
Py_INCREF
(
running_var
);
#else
PyGpuArrayObject
*
running_var
=
*
out_running_var
;
running_var
=
theano_try_copy
(
running_var
,
in_running_var
);
if
(
running_var
==
NULL
)
{
return
1
;
}
#endif
#endif
{
...
...
theano/gpuarray/tests/test_dnn.py
浏览文件 @
fd0e3b65
from
__future__
import
absolute_import
,
print_function
,
division
import
logging
from
collections
import
OrderedDict
from
nose.plugins.skip
import
SkipTest
from
nose_parameterized
import
parameterized
...
...
@@ -1531,6 +1532,51 @@ def test_dnn_batchnorm_train_without_running_averages():
f_abstract
(
X
,
Scale
,
Bias
,
Dy
)
def
test_dnn_batchnorm_train_inplace
():
# test inplace_running_mean and inplace_running_var
if
not
dnn
.
dnn_available
(
test_ctx_name
):
raise
SkipTest
(
dnn
.
dnn_available
.
msg
)
if
dnn
.
version
(
raises
=
False
)
<
5000
:
raise
SkipTest
(
"batch normalization requires cudnn v5+"
)
utt
.
seed_rng
()
x
,
scale
,
bias
=
T
.
tensor4
(
'x'
),
T
.
tensor4
(
'scale'
),
T
.
tensor4
(
'bias'
)
data_shape
=
(
5
,
10
,
30
,
25
)
param_shape
=
(
1
,
10
,
30
,
25
)
running_mean
=
gpuarray_shared_constructor
(
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
),
broadcastable
=
(
True
,
False
,
False
,
False
))
running_var
=
gpuarray_shared_constructor
(
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
),
broadcastable
=
(
True
,
False
,
False
,
False
))
# forward pass
out
,
x_mean
,
x_invstd
,
new_running_mean
,
new_running_var
=
\
dnn
.
dnn_batch_normalization_train
(
x
,
scale
,
bias
,
'per-activation'
,
epsilon
=
5e-3
,
running_average_factor
=
0.3
,
running_mean
=
running_mean
,
running_var
=
running_var
)
# update running averages
updates
=
OrderedDict
()
updates
[
running_mean
]
=
new_running_mean
updates
[
running_var
]
=
new_running_var
# compile
f
=
theano
.
function
([
x
,
scale
,
bias
],
[
out
,
x_mean
,
x_invstd
],
updates
=
updates
,
mode
=
mode_with_gpu
)
# check for the inplace settings
nodes
=
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
n
.
op
,
dnn
.
GpuDnnBatchNorm
)]
assert
len
(
nodes
)
==
1
assert
nodes
[
0
]
.
op
.
inplace_running_mean
assert
nodes
[
0
]
.
op
.
inplace_running_var
# run
X
=
4
+
3
*
numpy
.
random
.
randn
(
*
data_shape
)
.
astype
(
theano
.
config
.
floatX
)
Scale
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
Bias
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
f
(
X
,
Scale
,
Bias
)
def
test_batchnorm_inference
():
if
not
dnn
.
dnn_available
(
test_ctx_name
):
raise
SkipTest
(
dnn
.
dnn_available
.
msg
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论