提交 67495a12 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix infer_shape for pooling, skip unimplemented test cases

上级 66bb1078
...@@ -835,11 +835,12 @@ class GpuDnnPool(DnnBase): ...@@ -835,11 +835,12 @@ class GpuDnnPool(DnnBase):
desc = node.inputs[1].owner.op desc = node.inputs[1].owner.op
kh, kw = desc.ws kh, kw = desc.ws
sh, sw = desc.stride sh, sw = desc.stride
padh, padw = desc.pad
return [( return [(
shape[0][0], shape[0][0],
shape[0][1], shape[0][1],
(shape[0][2] - kh)//sh + 1, (shape[0][2] + 2*padh - kh)//sh + 1,
(shape[0][3] - kw)//sw + 1 (shape[0][3] + 2*padw - kw)//sw + 1
)] )]
def c_support_code_struct(self, node, name): def c_support_code_struct(self, node, name):
......
...@@ -73,11 +73,17 @@ def test_pooling(): ...@@ -73,11 +73,17 @@ def test_pooling():
if pad != (0, 0) and cuda.dnn.version() < 20: if pad != (0, 0) and cuda.dnn.version() < 20:
continue continue
if pad != (0, 0) and func is T.mean:
continue
for ws in (4, 2, 5): for ws in (4, 2, 5):
for stride in (2, 3): for stride in (2, 3):
if stride > ws: if stride > ws:
continue continue
if func is T.max: if func is T.max:
if pad[0] > stride or pad[1] > stride:
# Not implemented
continue
# We will check that the opt introduced it. # We will check that the opt introduced it.
out1 = max_pool_2d(x, (ws, ws), out1 = max_pool_2d(x, (ws, ws),
st=(stride, stride), st=(stride, stride),
...@@ -117,6 +123,9 @@ def test_pooling(): ...@@ -117,6 +123,9 @@ def test_pooling():
ws = 2 ws = 2
stride = 2 stride = 2
if pad[0] > stride or pad[1] > stride:
# Not implemented
continue
# This test the CPU grad + opt + GPU implemtentation # This test the CPU grad + opt + GPU implemtentation
def fn(x): def fn(x):
......
...@@ -257,7 +257,7 @@ class DownsampleFactorMax(Op): ...@@ -257,7 +257,7 @@ class DownsampleFactorMax(Op):
def infer_shape(self, node, in_shapes): def infer_shape(self, node, in_shapes):
shp = self.out_shape(in_shapes[0], self.ds, shp = self.out_shape(in_shapes[0], self.ds,
self.ignore_border, self.st) self.ignore_border, self.st, self.padding)
return [shp] return [shp]
def grad(self, inp, grads): def grad(self, inp, grads):
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论