提交 e0afde95 authored 作者: sebastien-j's avatar sebastien-j

Address nouiz's comments

上级 01f4490c
...@@ -313,8 +313,7 @@ class DownsampleFactorMax(Op): ...@@ -313,8 +313,7 @@ class DownsampleFactorMax(Op):
maxout = self(x) maxout = self(x)
return [MaxPoolGrad(self.ds, return [MaxPoolGrad(self.ds,
ignore_border=self.ignore_border, ignore_border=self.ignore_border,
st=self.st, padding=self.padding, st=self.st, padding=self.padding)(
mode=self.mode)(
x, maxout, gz)] x, maxout, gz)]
else: else:
return [AveragePoolGrad(self.ds, return [AveragePoolGrad(self.ds,
...@@ -607,9 +606,8 @@ class PoolGrad(Op): ...@@ -607,9 +606,8 @@ class PoolGrad(Op):
class MaxPoolGrad(PoolGrad): class MaxPoolGrad(PoolGrad):
def __init__(self, ds, ignore_border, st=None, padding=(0, 0), mode='max'): def __init__(self, ds, ignore_border, st=None, padding=(0, 0)):
PoolGrad.__init__(self, ds, ignore_border, st, padding, mode) PoolGrad.__init__(self, ds, ignore_border, st, padding, mode='max')
self.mode = 'max'
def make_node(self, x, maxout, gz): def make_node(self, x, maxout, gz):
# make_node should only be called by the grad function of # make_node should only be called by the grad function of
...@@ -784,12 +782,13 @@ class MaxPoolGrad(PoolGrad): ...@@ -784,12 +782,13 @@ class MaxPoolGrad(PoolGrad):
} }
""" % locals() """ % locals()
#def c_code_cache_version(self): def c_code_cache_version(self):
# return (0, 7) return (0, 7)
class AveragePoolGrad(PoolGrad): class AveragePoolGrad(PoolGrad):
def __init__(self, ds, ignore_border, st=None, padding=(0, 0), mode='average_inc_pad'): def __init__(self, ds, ignore_border, st=None, padding=(0, 0), mode='average_inc_pad'):
assert mode in ['sum', 'average_inc_pad', 'average_exc_pad']
PoolGrad.__init__(self, ds, ignore_border, st, padding, mode) PoolGrad.__init__(self, ds, ignore_border, st, padding, mode)
def make_node(self, x, gz): def make_node(self, x, gz):
...@@ -980,8 +979,6 @@ class DownsampleFactorMaxGradGrad(Op): ...@@ -980,8 +979,6 @@ class DownsampleFactorMaxGradGrad(Op):
return Apply(self, [x, maxout, gz], [x.type()]) return Apply(self, [x, maxout, gz], [x.type()])
def perform(self, node, inp, out): def perform(self, node, inp, out):
if self.mode != 'max':
raise theano.gof.utils.MethodNotDefined()
x, maxout, ggx = inp x, maxout, ggx = inp
z, = out z, = out
if len(x.shape) != 4: if len(x.shape) != 4:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论