提交 2c04d4d8 authored 作者: Sina Honari's avatar Sina Honari

considering the case when st != ds in the grad methods

上级 b90a03fb
...@@ -235,8 +235,6 @@ class DownsampleFactorMax(Op): ...@@ -235,8 +235,6 @@ class DownsampleFactorMax(Op):
x, = inp x, = inp
gz, = grads gz, = grads
maxout = self(x) maxout = self(x)
if self.st != self.ds:
return [theano.gradient.grad_not_implemented(self, 0, x)]
return [DownsampleFactorMaxGrad(self.ds, return [DownsampleFactorMaxGrad(self.ds,
ignore_border=self.ignore_border, ignore_border=self.ignore_border,
st=self.st)( st=self.st)(
...@@ -386,16 +384,14 @@ class DownsampleFactorMaxGrad(Op): ...@@ -386,16 +384,14 @@ class DownsampleFactorMaxGrad(Op):
def grad(self, inp, grads): def grad(self, inp, grads):
x, maxout, gz = inp x, maxout, gz = inp
ggx, = grads ggx, = grads
if self.st != self.ds:
return [theano.gradient.grad_not_implemented(self, 0, x),
theano.gradient.grad_not_implemented(self, 1, maxout),
theano.gradient.grad_not_implemented(self, 2, gz)]
return [theano.tensor.zeros_like(x), return [theano.tensor.zeros_like(x),
theano.tensor.zeros_like(maxout), theano.tensor.zeros_like(maxout),
DownsampleFactorMaxGradGrad( DownsampleFactorMaxGradGrad(
self.ds, ignore_border=self.ignore_border, st=self.st)(x, maxout, ggx)] self.ds, ignore_border=self.ignore_border, st=self.st)(x, maxout, ggx)]
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
if self.ds != self.st:
raise theano.gof.utils.MethodNotDefined()
x, z, gz = inp x, z, gz = inp
gx, = out gx, = out
fail = sub['fail'] fail = sub['fail']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论