提交 c386dd5d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Explain more how dummy is there for backward compat and cleanup a bit.

上级 7145695f
...@@ -17,7 +17,7 @@ class DownsampleFactorMaxGrad(object): ...@@ -17,7 +17,7 @@ class DownsampleFactorMaxGrad(object):
def __new__(self, ds, ignore_border, st=None, padding=(0, 0), mode='max'): def __new__(self, ds, ignore_border, st=None, padding=(0, 0), mode='max'):
if mode == 'max': if mode == 'max':
return MaxPoolGrad(ds=ds, ignore_border=ignore_border, st=st, return MaxPoolGrad(ds=ds, ignore_border=ignore_border, st=st,
padding=padding, mode='max') padding=padding)
else: else:
return AveragePoolGrad(ds=ds, ignore_border=ignore_border, st=st, return AveragePoolGrad(ds=ds, ignore_border=ignore_border, st=st,
padding=padding, mode=mode) padding=padding, mode=mode)
......
...@@ -625,9 +625,8 @@ class PoolGrad(Op): ...@@ -625,9 +625,8 @@ class PoolGrad(Op):
class MaxPoolGrad(PoolGrad): class MaxPoolGrad(PoolGrad):
def __init__(self, ds, ignore_border, st=None, padding=(0, 0)):
def __init__(self, ds, ignore_border, st=None, padding=(0, 0), mode='max'): PoolGrad.__init__(self, ds, ignore_border, st, padding, mode='max')
PoolGrad.__init__(self, ds, ignore_border, st, padding, mode)
def make_node(self, x, maxout, gz): def make_node(self, x, maxout, gz):
# make_node should only be called by the grad function of # make_node should only be called by the grad function of
...@@ -802,13 +801,15 @@ class MaxPoolGrad(PoolGrad): ...@@ -802,13 +801,15 @@ class MaxPoolGrad(PoolGrad):
class AveragePoolGrad(PoolGrad): class AveragePoolGrad(PoolGrad):
def __init__(self, ds, ignore_border, st=None, padding=(0, 0),
def __init__(self, ds, ignore_border, st=None, padding=(0, 0), mode='average_inc_pad'): mode='average_inc_pad'):
assert mode in ['sum', 'average_inc_pad', 'average_exc_pad'] assert mode in ['sum', 'average_inc_pad', 'average_exc_pad']
PoolGrad.__init__(self, ds, ignore_border, st, padding, mode) PoolGrad.__init__(self, ds, ignore_border, st, padding, mode)
# There is an extra dummy parameter to match the parameter count # There is an extra dummy parameter to match the parameter count
# of MaxPoolGrad. This is for backward compatibility. # of MaxPoolGrad. They have to keep the same interface because of
# the DownsampleFactorMaxGrad trick to keep old scripts working
# (see downsample.py for details on this).
def make_node(self, x, gz, dummy=None): def make_node(self, x, gz, dummy=None):
# make_node should only be called by the grad function of # make_node should only be called by the grad function of
# Pool, so these asserts should not fail. # Pool, so these asserts should not fail.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论