提交 4201bf7c authored 作者: Frederic's avatar Frederic

[ENH] Simplify DownsampleFactorMax.out_shape that is used by its infer_shape when strides == ds

上级 e4089d8d
...@@ -193,16 +193,23 @@ class DownsampleFactorMax(Op): ...@@ -193,16 +193,23 @@ class DownsampleFactorMax(Op):
c += padding[1] * 2 c += padding[1] * 2
if ignore_border: if ignore_border:
out_r = (r - ds[0]) // st[0] + 1 if ds[0] == st[0]:
out_c = (c - ds[1]) // st[1] + 1 nr = r // st[0]
if isinstance(r, theano.Variable):
nr = tensor.maximum(out_r, 0)
else: else:
nr = numpy.maximum(out_r, 0) out_r = (r - ds[0]) // st[0] + 1
if isinstance(c, theano.Variable): if isinstance(r, theano.Variable):
nc = tensor.maximum(out_c, 0) nr = tensor.maximum(out_r, 0)
else:
nr = numpy.maximum(out_r, 0)
if ds[1] == st[1]:
nc = c // st[1]
else: else:
nc = numpy.maximum(out_c, 0) out_c = (c - ds[1]) // st[1] + 1
if isinstance(c, theano.Variable):
nc = tensor.maximum(out_c, 0)
else:
nc = numpy.maximum(out_c, 0)
else: else:
if isinstance(r, theano.Variable): if isinstance(r, theano.Variable):
nr = tensor.switch(tensor.ge(st[0], ds[0]), nr = tensor.switch(tensor.ge(st[0], ds[0]),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论