Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
186056b8
提交
186056b8
authored
11月 10, 2016
作者:
Gijs van Tulder
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Compute running_mean and running_var using cuDNN.
上级
c4293e69
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
211 行增加
和
48 行删除
+211
-48
dnn.py
theano/gpuarray/dnn.py
+145
-35
dnn_batchnorm.c
theano/gpuarray/dnn_batchnorm.c
+36
-2
test_dnn.py
theano/gpuarray/tests/test_dnn.py
+30
-11
没有找到文件。
theano/gpuarray/dnn.py
浏览文件 @
186056b8
...
@@ -1647,48 +1647,98 @@ class GpuDnnBatchNorm(DnnBase):
...
@@ -1647,48 +1647,98 @@ class GpuDnnBatchNorm(DnnBase):
epsilon
epsilon
Epsilon value used in the batch normalization formula. Minimum allowed
Epsilon value used in the batch normalization formula. Minimum allowed
value is 1e-5 (imposed by cuDNN).
value is 1e-5 (imposed by cuDNN).
running_average_factor : float
Factor for updating the values or `running_mean` and `running_var`.
If the factor is close to one, the running averages will update quickly,
if the factor is close to zero it will update slowly.
running_mean : tensor or None
Previous value of the running mean. If this is given, the new value
``running_mean * (1 - r_a_factor) + batch mean * r_a_factor``
will be returned as one of the outputs of this function.
`running_mean` and `running_var` should either both be given or
both be None.
running_var : tensor or None
Previous value of the running variance. If this is given, the new value
``running_var * (1 - r_a_factor) + (m / (m - 1)) * batch var * r_a_factor``
will be returned as one of the outputs of this function,
where `m` is the product of lengths of the averaged-over dimensions.
`running_mean` and `running_var` should either both be given or
both be None.
"""
"""
__props__
=
(
'mode'
,)
__props__
=
(
'mode'
,
'running_averages'
)
def
__init__
(
self
,
mode
=
'per-activation'
):
def
__init__
(
self
,
mode
=
'per-activation'
,
running_averages
=
False
):
DnnBase
.
__init__
(
self
,
[
'dnn_batchnorm_base.c'
,
'dnn_batchnorm.c'
],
DnnBase
.
__init__
(
self
,
[
'dnn_batchnorm_base.c'
,
'dnn_batchnorm.c'
],
'dnn_batchnorm_op'
)
'dnn_batchnorm_op'
)
assert
(
mode
in
(
'per-activation'
,
'spatial'
))
assert
(
mode
in
(
'per-activation'
,
'spatial'
))
self
.
mode
=
mode
self
.
mode
=
mode
self
.
running_averages
=
running_averages
def
get_op_params
(
self
):
def
get_op_params
(
self
):
params
=
[]
params
=
[]
if
self
.
running_averages
:
params
.
append
((
'RUNNING_AVERAGES'
,
'1'
))
params
.
append
((
'MODE'
,
(
"CUDNN_BATCHNORM_SPATIAL"
params
.
append
((
'MODE'
,
(
"CUDNN_BATCHNORM_SPATIAL"
if
self
.
mode
==
"spatial"
if
self
.
mode
==
"spatial"
else
"CUDNN_BATCHNORM_PER_ACTIVATION"
)))
else
"CUDNN_BATCHNORM_PER_ACTIVATION"
)))
return
params
return
params
def
infer_shape
(
self
,
node
,
shape
):
def
infer_shape
(
self
,
node
,
shape
):
return
[
shape
[
0
]
,
shape
[
1
],
shape
[
1
]]
return
[
shape
[
0
]
]
+
[
shape
[
1
]]
*
(
len
(
node
.
outputs
)
-
1
)
def
make_node
(
self
,
x
,
scale
,
bias
,
epsilon
=
1e-4
):
def
make_node
(
self
,
x
,
scale
,
bias
,
epsilon
=
1e-4
,
running_average_factor
=
0.1
,
running_mean
=
None
,
running_var
=
None
):
assert
x
.
ndim
==
scale
.
ndim
==
bias
.
ndim
assert
x
.
ndim
in
(
4
,
5
)
assert
self
.
running_averages
==
(
running_mean
is
not
None
)
==
(
running_var
is
not
None
)
assert
(
running_mean
is
None
or
running_mean
.
ndim
==
x
.
ndim
)
assert
(
running_var
is
None
or
running_var
.
ndim
==
x
.
ndim
)
ctx_name
=
infer_context_name
(
x
,
scale
,
bias
)
ctx_name
=
infer_context_name
(
x
,
scale
,
bias
)
x
=
as_gpuarray_variable
(
x
,
ctx_name
)
x
=
as_gpuarray_variable
(
x
,
ctx_name
)
scale
=
as_gpuarray_variable
(
scale
,
ctx_name
)
scale
=
as_gpuarray_variable
(
scale
,
ctx_name
)
bias
=
as_gpuarray_variable
(
bias
,
ctx_name
)
bias
=
as_gpuarray_variable
(
bias
,
ctx_name
)
epsilon
=
as_scalar
(
epsilon
)
.
astype
(
'float64'
)
epsilon
=
as_scalar
(
epsilon
)
.
astype
(
'float64'
)
assert
x
.
ndim
==
scale
.
ndim
==
bias
.
ndim
running_average_factor
=
as_scalar
(
running_average_factor
)
.
astype
(
'float64'
)
assert
x
.
ndim
in
(
4
,
5
)
inputs
=
[
x
,
scale
,
bias
,
epsilon
,
running_average_factor
]
return
Apply
(
self
,
[
x
,
scale
,
bias
,
epsilon
],
[
x
.
type
(),
scale
.
type
(),
scale
.
type
()])
output_types
=
[
x
.
type
(),
scale
.
type
(),
scale
.
type
()]
if
running_mean
is
not
None
and
running_var
is
not
None
:
inputs
.
append
(
as_gpuarray_variable
(
running_mean
,
ctx_name
))
inputs
.
append
(
as_gpuarray_variable
(
running_var
,
ctx_name
))
output_types
.
append
(
scale
.
type
())
output_types
.
append
(
scale
.
type
())
return
Apply
(
self
,
inputs
,
output_types
)
def
grad
(
self
,
inputs
,
grads
):
def
grad
(
self
,
inputs
,
grads
):
x
,
scale
,
bias
,
epsilon
=
inputs
x
,
scale
,
bias
,
epsilon
,
running_average_factor
=
inputs
[:
5
]
dy
=
grads
[
0
]
dy
=
grads
[
0
]
_
,
x_mean
,
x_invstd
=
self
(
x
,
scale
,
bias
,
epsilon
)
_
,
x_mean
,
x_invstd
=
self
(
*
inputs
)[:
3
]
return
GpuDnnBatchNormGrad
(
self
.
mode
)(
x
,
dy
,
scale
,
x_mean
,
disconnected_outputs
=
[
x_invstd
,
epsilon
)
+
[
DisconnectedType
()()]
DisconnectedType
()(),
# epsilon
DisconnectedType
()()]
# running_average_factor
# Optional running_mean and running_var.
for
i
in
range
(
5
,
len
(
inputs
)):
disconnected_outputs
.
append
(
DisconnectedType
()())
return
GpuDnnBatchNormGrad
(
self
.
mode
)(
x
,
dy
,
scale
,
x_mean
,
x_invstd
,
epsilon
)
+
disconnected_outputs
def
connection_pattern
(
self
,
node
):
def
connection_pattern
(
self
,
node
):
# Specificy that epsilon is not connected to outputs.
# Specificy that epsilon and running_average_factor are not connected to outputs.
return
[[
True
,
True
,
True
],
[
True
,
True
,
True
],
[
True
,
True
,
True
],
patterns
=
[[
True
,
True
,
True
],
# x
[
False
,
False
,
False
]]
[
True
,
True
,
True
],
# scale
[
True
,
True
,
True
],
# bias
[
False
,
False
,
False
],
# epsilon
[
False
,
False
,
False
]]
# running_average_factor
# Optional running_mean and running_var are only
# connected to their new values.
for
i
in
range
(
5
,
len
(
node
.
inputs
)):
patterns
[
0
]
.
append
(
True
)
for
pattern
in
patterns
[
1
:]:
pattern
.
append
(
False
)
patterns
.
append
([
False
]
*
(
3
+
i
-
5
)
+
[
True
])
return
patterns
class
GpuDnnBatchNormInference
(
DnnBase
):
class
GpuDnnBatchNormInference
(
DnnBase
):
...
@@ -2405,7 +2455,8 @@ class RNNBlock(object):
...
@@ -2405,7 +2455,8 @@ class RNNBlock(object):
def
dnn_batch_normalization_train
(
inputs
,
gamma
,
beta
,
mode
=
'per-activation'
,
def
dnn_batch_normalization_train
(
inputs
,
gamma
,
beta
,
mode
=
'per-activation'
,
epsilon
=
1e-4
):
epsilon
=
1e-4
,
running_average_factor
=
0.1
,
running_mean
=
None
,
running_var
=
None
):
"""
"""
Performs batch normalization of the given inputs, using the mean and
Performs batch normalization of the given inputs, using the mean and
variance of the inputs.
variance of the inputs.
...
@@ -2425,6 +2476,23 @@ def dnn_batch_normalization_train(inputs, gamma, beta, mode='per-activation',
...
@@ -2425,6 +2476,23 @@ def dnn_batch_normalization_train(inputs, gamma, beta, mode='per-activation',
epsilon : float
epsilon : float
Epsilon value used in the batch normalization formula. Minimum allowed
Epsilon value used in the batch normalization formula. Minimum allowed
value is 1e-5 (imposed by cuDNN).
value is 1e-5 (imposed by cuDNN).
running_average_factor : float
Factor for updating the values or `running_mean` and `running_var`.
If the factor is close to one, the running averages will update quickly,
if the factor is close to zero it will update slowly.
running_mean : tensor or None
Previous value of the running mean. If this is given, the new value
``running_mean * (1 - r_a_factor) + batch mean * r_a_factor``
will be returned as one of the outputs of this function.
`running_mean` and `running_var` should either both be given or
both be None.
running_var : tensor or None
Previous value of the running variance. If this is given, the new value
``running_var * (1 - r_a_factor) + (m / (m - 1)) * batch var * r_a_factor``
will be returned as one of the outputs of this function,
where `m` is the product of lengths of the averaged-over dimensions.
`running_mean` and `running_var` should either both be given or
both be None.
Returns
Returns
-------
-------
...
@@ -2434,6 +2502,12 @@ def dnn_batch_normalization_train(inputs, gamma, beta, mode='per-activation',
...
@@ -2434,6 +2502,12 @@ def dnn_batch_normalization_train(inputs, gamma, beta, mode='per-activation',
Means of `inputs` across the normalization axes.
Means of `inputs` across the normalization axes.
invstd : tensor
invstd : tensor
Inverse standard deviations of `inputs` across the normalization axes.
Inverse standard deviations of `inputs` across the normalization axes.
new_running_mean : tensor
New value of the running mean (only if both `running_mean` and
`running_var` were given).
new_running_var : tensor
New value of the running variance (only if both `running_var` and
`running_mean` were given).
Notes
Notes
-----
-----
...
@@ -2445,9 +2519,16 @@ def dnn_batch_normalization_train(inputs, gamma, beta, mode='per-activation',
...
@@ -2445,9 +2519,16 @@ def dnn_batch_normalization_train(inputs, gamma, beta, mode='per-activation',
axes = 0 if mode == 'per-activation' else (0, 2, 3)
axes = 0 if mode == 'per-activation' else (0, 2, 3)
mean = inputs.mean(axes, keepdims=True)
mean = inputs.mean(axes, keepdims=True)
invstd = T.inv(T.sqrt(inputs.var(axes, keepdims=True) + epsilon))
var = inputs.var(axes, keepdims=True)
invstd = T.inv(T.sqrt(var + epsilon))
out = (inputs - mean) * gamma * invstd + beta
out = (inputs - mean) * gamma * invstd + beta
m = T.cast(T.prod(inputs.shape) / T.prod(mean.shape), 'float32')
running_mean = running_mean * (1 - running_average_factor) +
\\
mean * running_average_factor
running_var = running_var * (1 - running_average_factor) +
\\
(m / (m - 1)) * var * running_average_factor
For 5d tensors, the axes are (0, 2, 3, 4).
For 5d tensors, the axes are (0, 2, 3, 4).
"""
"""
ndim
=
inputs
.
ndim
ndim
=
inputs
.
ndim
...
@@ -2455,28 +2536,60 @@ def dnn_batch_normalization_train(inputs, gamma, beta, mode='per-activation',
...
@@ -2455,28 +2536,60 @@ def dnn_batch_normalization_train(inputs, gamma, beta, mode='per-activation',
raise
ValueError
(
"gamma and beta must be of the same dimensionality "
raise
ValueError
(
"gamma and beta must be of the same dimensionality "
"as inputs; got
%
d and
%
d instead of
%
d"
%
"as inputs; got
%
d and
%
d instead of
%
d"
%
(
gamma
.
ndim
,
beta
.
ndim
,
ndim
))
(
gamma
.
ndim
,
beta
.
ndim
,
ndim
))
if
(
running_mean
is
None
)
!=
(
running_var
is
None
):
raise
ValueError
(
"running_mean and running_var must either both be "
"given or both be None"
)
if
running_mean
is
not
None
and
running_mean
.
ndim
!=
ndim
:
raise
ValueError
(
"running_mean must be of the same dimensionality "
"as inputs; got
%
d instead of
%
d"
%
(
running_mean
.
ndim
,
ndim
))
if
running_var
is
not
None
and
running_var
.
ndim
!=
ndim
:
raise
ValueError
(
"running_var must be of the same dimensionality "
"as inputs; got
%
d instead of
%
d"
%
(
running_var
.
ndim
,
ndim
))
if
epsilon
<
1e-5
:
if
epsilon
<
1e-5
:
raise
ValueError
(
"epsilon must be at least 1e-5, got
%
f"
%
epsilon
)
raise
ValueError
(
"epsilon must be at least 1e-5, got
%
f"
%
epsilon
)
running_averages
=
(
running_var
is
not
None
and
running_var
is
not
None
)
if
ndim
<
4
:
if
ndim
<
4
:
inputs
=
theano
.
tensor
.
shape_padright
(
inputs
,
4
-
ndim
)
inputs
=
theano
.
tensor
.
shape_padright
(
inputs
,
4
-
ndim
)
gamma
=
theano
.
tensor
.
shape_padright
(
gamma
,
4
-
ndim
)
gamma
=
theano
.
tensor
.
shape_padright
(
gamma
,
4
-
ndim
)
beta
=
theano
.
tensor
.
shape_padright
(
beta
,
4
-
ndim
)
beta
=
theano
.
tensor
.
shape_padright
(
beta
,
4
-
ndim
)
if
running_averages
:
running_mean
=
theano
.
tensor
.
shape_padright
(
running_mean
,
4
-
ndim
)
running_var
=
theano
.
tensor
.
shape_padright
(
running_var
,
4
-
ndim
)
elif
ndim
>
5
:
elif
ndim
>
5
:
inputs_shape
=
inputs
.
shape
inputs_shape
=
inputs
.
shape
params_shape
=
gamma
.
shape
params_shape
=
gamma
.
shape
inputs
=
theano
.
tensor
.
flatten
(
inputs
,
5
)
inputs
=
theano
.
tensor
.
flatten
(
inputs
,
5
)
gamma
=
theano
.
tensor
.
flatten
(
gamma
,
5
)
gamma
=
theano
.
tensor
.
flatten
(
gamma
,
5
)
beta
=
theano
.
tensor
.
flatten
(
beta
,
5
)
beta
=
theano
.
tensor
.
flatten
(
beta
,
5
)
batchnorm_op
=
GpuDnnBatchNorm
(
mode
=
mode
)
if
running_averages
:
result
=
tuple
(
batchnorm_op
(
gpu_contiguous
(
inputs
),
gpu_contiguous
(
gamma
),
running_mean
=
theano
.
tensor
.
flatten
(
running_mean
,
5
)
gpu_contiguous
(
beta
),
epsilon
=
epsilon
))
running_var
=
theano
.
tensor
.
flatten
(
running_var
,
5
)
batchnorm_op
=
GpuDnnBatchNorm
(
mode
=
mode
,
running_averages
=
running_averages
)
if
running_averages
:
out
,
mean
,
invstd
,
new_running_mean
,
new_running_var
=
batchnorm_op
(
gpu_contiguous
(
inputs
),
gpu_contiguous
(
gamma
),
gpu_contiguous
(
beta
),
epsilon
=
epsilon
,
running_average_factor
=
running_average_factor
,
running_mean
=
gpu_contiguous
(
running_mean
),
running_var
=
gpu_contiguous
(
running_var
))
if
new_running_mean
.
broadcastable
!=
running_mean
.
broadcastable
:
new_running_mean
=
tensor
.
patternbroadcast
(
new_running_mean
,
running_mean
.
broadcastable
)
if
new_running_var
.
broadcastable
!=
running_var
.
broadcastable
:
new_running_var
=
tensor
.
patternbroadcast
(
new_running_var
,
running_var
.
broadcastable
)
result
=
(
out
,
mean
,
invstd
,
new_running_mean
,
new_running_var
)
else
:
result
=
batchnorm_op
(
gpu_contiguous
(
inputs
),
gpu_contiguous
(
gamma
),
gpu_contiguous
(
beta
),
epsilon
=
epsilon
)
if
ndim
<
4
:
if
ndim
<
4
:
result
=
tuple
(
theano
.
tensor
.
flatten
(
r
,
ndim
)
for
r
in
result
)
result
=
tuple
(
theano
.
tensor
.
flatten
(
r
,
ndim
)
for
r
in
result
)
elif
ndim
>
5
:
elif
ndim
>
5
:
result
=
(
theano
.
tensor
.
reshape
(
result
[
0
],
inputs_shape
),
result
=
(
theano
.
tensor
.
reshape
(
result
[
0
],
inputs_shape
),)
+
tuple
(
theano
.
tensor
.
reshape
(
result
[
1
],
params_shape
),
theano
.
tensor
.
reshape
(
r
,
params_shape
)
for
r
in
result
[
1
:])
theano
.
tensor
.
reshape
(
result
[
2
],
params_shape
))
return
result
return
result
...
@@ -2974,6 +3087,10 @@ def local_abstract_batch_norm_train_cudnn(node):
...
@@ -2974,6 +3087,10 @@ def local_abstract_batch_norm_train_cudnn(node):
return
None
return
None
if
eps
<
1e-5
:
if
eps
<
1e-5
:
return
None
return
None
try
:
running_average_factor
=
theano
.
tensor
.
get_scalar_constant_value
(
running_average_factor
)
except
theano
.
tensor
.
NotScalarConstantError
:
return
None
ctx
=
infer_context_name
(
*
node
.
inputs
)
ctx
=
infer_context_name
(
*
node
.
inputs
)
if
not
dnn_available
(
ctx
):
if
not
dnn_available
(
ctx
):
...
@@ -2983,19 +3100,12 @@ def local_abstract_batch_norm_train_cudnn(node):
...
@@ -2983,19 +3100,12 @@ def local_abstract_batch_norm_train_cudnn(node):
scale
=
as_gpuarray_variable
(
scale
,
context_name
=
ctx
)
scale
=
as_gpuarray_variable
(
scale
,
context_name
=
ctx
)
bias
=
as_gpuarray_variable
(
bias
,
context_name
=
ctx
)
bias
=
as_gpuarray_variable
(
bias
,
context_name
=
ctx
)
out
,
mean
,
invstd
=
dnn_batch_normalization_train
(
x
,
scale
,
bias
,
mode
,
eps
)
inputs
=
[
x
,
scale
,
bias
,
mode
,
eps
,
running_average_factor
]
if
running_mean
is
not
None
and
running_var
is
not
None
:
results
=
[
out
,
mean
,
invstd
]
inputs
.
append
(
running_mean
)
if
running_mean
is
not
None
:
inputs
.
append
(
running_var
)
running_mean
=
running_mean
*
(
1
-
running_average_factor
)
+
\
mean
*
running_average_factor
results
=
list
(
dnn_batch_normalization_train
(
*
inputs
))
results
.
append
(
running_mean
)
if
running_var
is
not
None
:
var
=
x
.
var
(
axis
=
axes
,
keepdims
=
True
)
m
=
tensor
.
cast
(
tensor
.
prod
(
x
.
shape
)
/
tensor
.
prod
(
scale
.
shape
),
theano
.
config
.
floatX
)
running_var
=
running_var
*
(
1
-
running_average_factor
)
+
\
(
m
/
(
m
-
1
))
*
var
*
running_average_factor
results
.
append
(
running_var
)
# If the original output was on CPU, we have to transfer it
# If the original output was on CPU, we have to transfer it
for
i
in
range
(
len
(
node
.
outputs
)):
for
i
in
range
(
len
(
node
.
outputs
)):
...
...
theano/gpuarray/dnn_batchnorm.c
浏览文件 @
186056b8
...
@@ -2,8 +2,19 @@
...
@@ -2,8 +2,19 @@
int
dnn_batchnorm_op
(
PyGpuArrayObject
*
inp
,
PyGpuArrayObject
*
scale
,
int
dnn_batchnorm_op
(
PyGpuArrayObject
*
inp
,
PyGpuArrayObject
*
scale
,
PyGpuArrayObject
*
bias
,
npy_float64
epsilon
,
PyGpuArrayObject
*
bias
,
npy_float64
epsilon
,
PyGpuArrayObject
**
outp
,
PyGpuArrayObject
**
x_mean
,
npy_float64
running_average_factor
,
PyGpuArrayObject
**
x_invstd
,
cudnnHandle_t
_handle
)
{
#ifdef RUNNING_AVERAGES
PyGpuArrayObject
*
in_running_mean
,
PyGpuArrayObject
*
in_running_var
,
#endif
PyGpuArrayObject
**
outp
,
PyGpuArrayObject
**
x_mean
,
PyGpuArrayObject
**
x_invstd
,
#ifdef RUNNING_AVERAGES
PyGpuArrayObject
**
out_running_mean
,
PyGpuArrayObject
**
out_running_var
,
#endif
cudnnHandle_t
_handle
)
{
PyGpuContextObject
*
c
=
inp
->
context
;
PyGpuContextObject
*
c
=
inp
->
context
;
if
(
c_set_tensorNd
(
inp
,
bn_input
)
!=
0
)
if
(
c_set_tensorNd
(
inp
,
bn_input
)
!=
0
)
...
@@ -24,6 +35,19 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale,
...
@@ -24,6 +35,19 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale,
if
(
c_set_tensorNd
(
*
outp
,
bn_output
)
!=
0
)
if
(
c_set_tensorNd
(
*
outp
,
bn_output
)
!=
0
)
return
1
;
return
1
;
#ifdef RUNNING_AVERAGES
PyGpuArrayObject
*
running_mean
=
*
out_running_mean
;
PyGpuArrayObject
*
running_var
=
*
out_running_var
;
running_mean
=
theano_try_copy
(
running_mean
,
in_running_mean
);
if
(
running_mean
==
NULL
)
{
return
1
;
}
running_var
=
theano_try_copy
(
running_var
,
in_running_var
);
if
(
running_var
==
NULL
)
{
return
1
;
}
#endif
{
{
const
float
falpha
=
1
.;
const
float
falpha
=
1
.;
const
float
fbeta
=
0
.;
const
float
fbeta
=
0
.;
...
@@ -50,9 +74,15 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale,
...
@@ -50,9 +74,15 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale,
bn_params
,
bn_params
,
PyGpuArray_DEV_DATA
(
scale
),
PyGpuArray_DEV_DATA
(
scale
),
PyGpuArray_DEV_DATA
(
bias
),
PyGpuArray_DEV_DATA
(
bias
),
#ifdef RUNNING_AVERAGES
running_average_factor
,
PyGpuArray_DEV_DATA
(
running_mean
),
PyGpuArray_DEV_DATA
(
running_var
),
#else
0
,
0
,
NULL
,
// running mean, deliberately unused
NULL
,
// running mean, deliberately unused
NULL
,
// running var, deliberately unused
NULL
,
// running var, deliberately unused
#endif
epsilon
,
epsilon
,
PyGpuArray_DEV_DATA
(
*
x_mean
),
PyGpuArray_DEV_DATA
(
*
x_mean
),
PyGpuArray_DEV_DATA
(
*
x_invstd
)
PyGpuArray_DEV_DATA
(
*
x_invstd
)
...
@@ -62,6 +92,10 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale,
...
@@ -62,6 +92,10 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale,
cudnnGetErrorString
(
err
));
cudnnGetErrorString
(
err
));
return
1
;
return
1
;
}
}
#ifdef RUNNING_AVERAGES
*
out_running_mean
=
running_mean
;
*
out_running_var
=
running_var
;
#endif
}
}
return
0
;
return
0
;
}
}
theano/gpuarray/tests/test_dnn.py
浏览文件 @
186056b8
...
@@ -1393,8 +1393,11 @@ def test_dnn_batchnorm_train():
...
@@ -1393,8 +1393,11 @@ def test_dnn_batchnorm_train():
running_average_factor
=
0.3
running_average_factor
=
0.3
# forward pass, direct interface
# forward pass, direct interface
out_gpu
,
x_mean_gpu
,
x_invstd_gpu
=
dnn
.
dnn_batch_normalization_train
(
out_gpu
,
x_mean_gpu
,
x_invstd_gpu
,
\
x
,
scale
,
bias
,
mode
,
eps
)
out_running_mean_gpu
,
out_running_var_gpu
=
\
dnn
.
dnn_batch_normalization_train
(
x
,
scale
,
bias
,
mode
,
eps
,
running_average_factor
,
running_mean
,
running_var
)
# forward pass, abstract interface
# forward pass, abstract interface
out_abstract
,
x_mean_abstract
,
x_invstd_abstract
,
\
out_abstract
,
x_mean_abstract
,
x_invstd_abstract
,
\
out_running_mean_abstract
,
out_running_var_abstract
=
\
out_running_mean_abstract
,
out_running_var_abstract
=
\
...
@@ -1424,8 +1427,9 @@ def test_dnn_batchnorm_train():
...
@@ -1424,8 +1427,9 @@ def test_dnn_batchnorm_train():
# reference backward pass
# reference backward pass
grads_ref
=
T
.
grad
(
None
,
wrt
=
[
x
,
scale
,
bias
],
known_grads
=
{
out_ref
:
dy
})
grads_ref
=
T
.
grad
(
None
,
wrt
=
[
x
,
scale
,
bias
],
known_grads
=
{
out_ref
:
dy
})
# compile
# compile
f_gpu
=
theano
.
function
([
x
,
scale
,
bias
,
dy
],
f_gpu
=
theano
.
function
([
x
,
scale
,
bias
,
running_mean
,
running_var
,
dy
],
[
out_gpu
,
x_mean_gpu
,
x_invstd_gpu
]
+
grads_gpu
,
[
out_gpu
,
x_mean_gpu
,
x_invstd_gpu
,
out_running_mean_gpu
,
out_running_var_gpu
]
+
grads_gpu
,
mode
=
mode_with_gpu
)
mode
=
mode_with_gpu
)
f_abstract
=
theano
.
function
([
x
,
scale
,
bias
,
running_mean
,
running_var
,
dy
],
f_abstract
=
theano
.
function
([
x
,
scale
,
bias
,
running_mean
,
running_var
,
dy
],
[
out_abstract
,
x_mean_abstract
,
x_invstd_abstract
,
[
out_abstract
,
x_mean_abstract
,
x_invstd_abstract
,
...
@@ -1455,13 +1459,16 @@ def test_dnn_batchnorm_train():
...
@@ -1455,13 +1459,16 @@ def test_dnn_batchnorm_train():
Bias
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
Bias
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
Running_mean
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
Running_mean
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
Running_var
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
Running_var
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
outputs_gpu
=
f_gpu
(
X
,
Scale
,
Bias
,
Dy
)
outputs_gpu
=
f_gpu
(
X
,
Scale
,
Bias
,
Running_mean
,
Running_var
,
Dy
)
outputs_abstract
=
f_abstract
(
X
,
Scale
,
Bias
,
Running_mean
,
Running_var
,
Dy
)
outputs_abstract
=
f_abstract
(
X
,
Scale
,
Bias
,
Running_mean
,
Running_var
,
Dy
)
outputs_ref
=
f_ref
(
X
,
Scale
,
Bias
,
Running_mean
,
Running_var
,
Dy
)
outputs_ref
=
f_ref
(
X
,
Scale
,
Bias
,
Running_mean
,
Running_var
,
Dy
)
# compare outputs
# compare outputs
utt
.
assert_allclose
(
outputs_gpu
[
0
],
outputs_ref
[
0
])
# out
utt
.
assert_allclose
(
outputs_gpu
[
0
],
outputs_ref
[
0
])
# out
utt
.
assert_allclose
(
outputs_gpu
[
1
],
outputs_ref
[
1
])
# mean
utt
.
assert_allclose
(
outputs_gpu
[
1
],
outputs_ref
[
1
])
# mean
utt
.
assert_allclose
(
outputs_gpu
[
2
],
outputs_ref
[
2
])
# invstd
utt
.
assert_allclose
(
outputs_gpu
[
2
],
outputs_ref
[
2
])
# invstd
utt
.
assert_allclose
(
outputs_gpu
[
3
],
outputs_ref
[
3
])
# running_mean
utt
.
assert_allclose
(
numpy
.
nan_to_num
(
outputs_gpu
[
4
]),
numpy
.
nan_to_num
(
outputs_ref
[
4
]))
# running_var
utt
.
assert_allclose
(
outputs_abstract
[
0
],
outputs_ref
[
0
])
# out
utt
.
assert_allclose
(
outputs_abstract
[
0
],
outputs_ref
[
0
])
# out
utt
.
assert_allclose
(
outputs_abstract
[
1
],
outputs_ref
[
1
])
# mean
utt
.
assert_allclose
(
outputs_abstract
[
1
],
outputs_ref
[
1
])
# mean
utt
.
assert_allclose
(
outputs_abstract
[
2
],
outputs_ref
[
2
])
# invstd
utt
.
assert_allclose
(
outputs_abstract
[
2
],
outputs_ref
[
2
])
# invstd
...
@@ -1469,9 +1476,9 @@ def test_dnn_batchnorm_train():
...
@@ -1469,9 +1476,9 @@ def test_dnn_batchnorm_train():
utt
.
assert_allclose
(
numpy
.
nan_to_num
(
outputs_abstract
[
4
]),
utt
.
assert_allclose
(
numpy
.
nan_to_num
(
outputs_abstract
[
4
]),
numpy
.
nan_to_num
(
outputs_ref
[
4
]))
# running_var
numpy
.
nan_to_num
(
outputs_ref
[
4
]))
# running_var
# compare gradients
# compare gradients
utt
.
assert_allclose
(
outputs_gpu
[
3
],
outputs_ref
[
5
],
atol
=
2e-4
)
# dx
utt
.
assert_allclose
(
outputs_gpu
[
5
],
outputs_ref
[
5
],
atol
=
2e-4
)
# dx
utt
.
assert_allclose
(
outputs_gpu
[
4
],
outputs_ref
[
6
],
rtol
=
4e-4
,
atol
=
1e-4
)
# dscale
utt
.
assert_allclose
(
outputs_gpu
[
6
],
outputs_ref
[
6
],
rtol
=
4e-4
,
atol
=
1e-4
)
# dscale
utt
.
assert_allclose
(
outputs_gpu
[
5
],
outputs_ref
[
7
])
# dbias
utt
.
assert_allclose
(
outputs_gpu
[
7
],
outputs_ref
[
7
])
# dbias
utt
.
assert_allclose
(
outputs_abstract
[
5
],
outputs_ref
[
5
],
atol
=
2e-4
)
# dx
utt
.
assert_allclose
(
outputs_abstract
[
5
],
outputs_ref
[
5
],
atol
=
2e-4
)
# dx
utt
.
assert_allclose
(
outputs_abstract
[
6
],
outputs_ref
[
6
],
rtol
=
4e-4
,
atol
=
1e-4
)
# dscale
utt
.
assert_allclose
(
outputs_abstract
[
6
],
outputs_ref
[
6
],
rtol
=
4e-4
,
atol
=
1e-4
)
# dscale
utt
.
assert_allclose
(
outputs_abstract
[
7
],
outputs_ref
[
7
])
# dbias
utt
.
assert_allclose
(
outputs_abstract
[
7
],
outputs_ref
[
7
])
# dbias
...
@@ -1490,11 +1497,22 @@ def test_dnn_batchnorm_train_without_running_averages():
...
@@ -1490,11 +1497,22 @@ def test_dnn_batchnorm_train_without_running_averages():
param_shape
=
(
1
,
10
,
30
,
25
)
param_shape
=
(
1
,
10
,
30
,
25
)
# forward pass
# forward pass
out
,
x_mean
,
x_invstd
=
bn
.
batch_normalization_train
(
x
,
scale
,
bias
,
'per-activation'
)
out_gpu
,
x_mean_gpu
,
x_invstd_gpu
=
\
dnn
.
dnn_batch_normalization_train
(
x
,
scale
,
bias
,
'per-activation'
)
out_abstract
,
x_mean_abstract
,
x_invstd_abstract
=
\
bn
.
batch_normalization_train
(
x
,
scale
,
bias
,
'per-activation'
)
# backward pass
# backward pass
grads
=
T
.
grad
(
None
,
wrt
=
[
x
,
scale
,
bias
],
known_grads
=
{
out
:
dy
})
grads_gpu
=
T
.
grad
(
None
,
wrt
=
[
x
,
scale
,
bias
],
known_grads
=
{
out_gpu
:
dy
})
grads_abstract
=
T
.
grad
(
None
,
wrt
=
[
x
,
scale
,
bias
],
known_grads
=
{
out_gpu
:
dy
})
# compile
# compile
f_abstract
=
theano
.
function
([
x
,
scale
,
bias
,
dy
],
[
out
,
x_mean
,
x_invstd
]
+
grads
,
mode
=
mode_with_gpu
)
f_gpu
=
theano
.
function
([
x
,
scale
,
bias
,
dy
],
[
out_gpu
,
x_mean_gpu
,
x_invstd_gpu
]
+
grads_gpu
,
mode
=
mode_with_gpu
)
f_abstract
=
theano
.
function
([
x
,
scale
,
bias
,
dy
],
[
out_abstract
,
x_mean_abstract
,
x_invstd_abstract
]
+
grads_abstract
,
mode
=
mode_with_gpu
)
# check if the abstract Ops have been replaced
# check if the abstract Ops have been replaced
assert
any
([
isinstance
(
n
.
op
,
dnn
.
GpuDnnBatchNorm
)
assert
any
([
isinstance
(
n
.
op
,
dnn
.
GpuDnnBatchNorm
)
for
n
in
f_abstract
.
maker
.
fgraph
.
toposort
()])
for
n
in
f_abstract
.
maker
.
fgraph
.
toposort
()])
...
@@ -1509,6 +1527,7 @@ def test_dnn_batchnorm_train_without_running_averages():
...
@@ -1509,6 +1527,7 @@ def test_dnn_batchnorm_train_without_running_averages():
Dy
=
-
1
+
2
*
numpy
.
random
.
randn
(
*
data_shape
)
.
astype
(
theano
.
config
.
floatX
)
Dy
=
-
1
+
2
*
numpy
.
random
.
randn
(
*
data_shape
)
.
astype
(
theano
.
config
.
floatX
)
Scale
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
Scale
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
Bias
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
Bias
=
numpy
.
random
.
randn
(
*
param_shape
)
.
astype
(
theano
.
config
.
floatX
)
f_gpu
(
X
,
Scale
,
Bias
,
Dy
)
f_abstract
(
X
,
Scale
,
Bias
,
Dy
)
f_abstract
(
X
,
Scale
,
Bias
,
Dy
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论