提交 d878af4f authored 作者: Harm de Vries's avatar Harm de Vries 提交者: Frederic Bastien

fix other tests

上级 26d47b89
......@@ -1124,7 +1124,7 @@ class GpuDnnPool(DnnBase):
return [('MODE_FLAG', mode_flag)]
def make_node(self, img, ws, stride, pad):
ctx_name = infer_context_name(img, ws, stride, pad)
ctx_name = infer_context_name(img)
img = as_gpuarray_variable(img, ctx_name)
ws = tensor.as_tensor_variable(ws)
......
......@@ -610,14 +610,9 @@ class TestDnnInferShapes(utt.InferShapeTester):
[(1, 1), (2, 2), (3, 3)],
modes
):
desc = dnn.GpuDnnPoolDesc(
ws=params[0],
stride=params[1],
mode=params[2]
)()
self._compile_and_check(
[img],
[dnn.GpuDnnPool()(img, desc)],
[dnn.GpuDnnPool(mode=params[2])(img, params[0], params[1], (0, 0))],
[img_val],
dnn.GpuDnnPool
)
......@@ -646,16 +641,13 @@ class TestDnnInferShapes(utt.InferShapeTester):
[(1, 1), (2, 2), (3, 3)],
['max', 'average_inc_pad']
):
desc = dnn.GpuDnnPoolDesc(
ws=params[0],
stride=params[1],
mode=params[2]
)()
pool_grad = dnn.GpuDnnPoolGrad()(
pool_grad = dnn.GpuDnnPoolGrad(mode=params[2])(
img,
out,
img_grad,
desc
params[0],
params[1],
(0, 0)
)
self._compile_and_check(
[img, img_grad, out],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论