提交 60238616 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Improve loading of old batch norm pickles.

上级 db9db1d0
......@@ -1689,6 +1689,20 @@ class GpuDnnBatchNorm(DnnBase):
if self.running_averages and self.inplace_running_var:
self.destroy_map[4] = [6]
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, 'running_average_factor'):
self.running_average_factor = 0
if not hasattr(self, 'running_averages'):
self.running_averages = False
if not (hasattr(self, 'inplace_running_mean') and
hasattr(self, 'inplace_running_var') and
hasattr(self, 'inplace_output')):
self.inplace_running_mean = False
self.inplace_running_var = False
self.inplace_output = False
self.destroy_map = {}
def get_op_params(self):
params = []
if self.inplace_output:
......@@ -1788,6 +1802,11 @@ class GpuDnnBatchNormInference(DnnBase):
if self.inplace:
self.destroy_map = {0: [0]}
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, 'inplace'):
self.inplace = False
def get_op_params(self):
params = []
if self.inplace:
......
......@@ -2447,6 +2447,11 @@ class GpuDnnBatchNormInference(GpuDnnBatchNormBase):
if self.inplace:
self.destroy_map = {0: [0]}
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, 'inplace'):
self.inplace = False
def get_op_params(self):
params = []
if self.inplace:
......@@ -2582,9 +2587,9 @@ class GpuDnnBatchNorm(GpuDnnBatchNormBase):
Note: scale and bias must follow the same tensor layout!
"""
__props__ = ('mode', 'epsilon', 'running_averages',
'inplace_running_mean', 'inplace_running_var',
'inplace_output')
__props__ = ('mode', 'epsilon', 'running_average_factor',
'running_averages', 'inplace_running_mean',
'inplace_running_var', 'inplace_output')
tensor_descs = ['bn_input', 'bn_output', 'bn_params']
def __init__(self, mode='per-activation', epsilon=1e-4,
......@@ -2605,6 +2610,20 @@ class GpuDnnBatchNorm(GpuDnnBatchNormBase):
if self.running_averages and self.inplace_running_var:
self.destroy_map[4] = [4]
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, 'running_average_factor'):
self.running_average_factor = 0
if not hasattr(self, 'running_averages'):
self.running_averages = False
if not (hasattr(self, 'inplace_running_mean') and
hasattr(self, 'inplace_running_var') and
hasattr(self, 'inplace_output')):
self.inplace_running_mean = False
self.inplace_running_var = False
self.inplace_output = False
self.destroy_map = {}
def get_op_params(self):
params = []
if self.inplace_output:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论