提交 e1c65a56 authored 作者: Harm de Vries's avatar Harm de Vries

remove test

上级 ca19b615
...@@ -1390,12 +1390,24 @@ class GpuDnnPool(DnnBase): ...@@ -1390,12 +1390,24 @@ class GpuDnnPool(DnnBase):
node.inputs[1] = ws node.inputs[1] = ws
node.inputs.append(st) node.inputs.append(st)
node.inputs.append(pad) node.inputs.append(pad)
storage_map[ws] = [None] if isinstance(ws, theano.constant):
storage_map[st] = [None] storage_map[ws] = [ws.data]
storage_map[pad] = [None] compute_map[ws] = [True]
compute_map[ws] = [False] else:
compute_map[st] = [False] storage_map[ws] = [None]
compute_map[pad] = [False] compute_map[ws] = [False]
if isinstance(st, theano.constant):
storage_map[st] = [st.data]
compute_map[st] = [True]
else:
storage_map[st] = [None]
compute_map[st] = [False]
if isinstance(pad, theano.constant):
storage_map[pad] = [pad.data]
compute_map[pad] = [True]
else:
storage_map[pad] = [None]
compute_map[pad] = [False]
def make_node(self, img, ws, stride, pad): def make_node(self, img, ws, stride, pad):
img = as_cuda_ndarray_variable(img) img = as_cuda_ndarray_variable(img)
...@@ -1615,12 +1627,24 @@ class GpuDnnPoolGrad(DnnBase): ...@@ -1615,12 +1627,24 @@ class GpuDnnPoolGrad(DnnBase):
node.inputs[3] = ws node.inputs[3] = ws
node.inputs.append(st) node.inputs.append(st)
node.inputs.append(pad) node.inputs.append(pad)
storage_map[ws] = [None] if isinstance(ws, theano.constant):
storage_map[st] = [None] storage_map[ws] = [ws.data]
storage_map[pad] = [None] compute_map[ws] = [True]
compute_map[ws] = [False] else:
compute_map[st] = [False] storage_map[ws] = [None]
compute_map[pad] = [False] compute_map[ws] = [False]
if isinstance(st, theano.constant):
storage_map[st] = [st.data]
compute_map[st] = [True]
else:
storage_map[st] = [None]
compute_map[st] = [False]
if isinstance(pad, theano.constant):
storage_map[pad] = [pad.data]
compute_map[pad] = [True]
else:
storage_map[pad] = [None]
compute_map[pad] = [False]
def make_node(self, inp, out, inp_grad, ws, stride, pad): def make_node(self, inp, out, inp_grad, ws, stride, pad):
inp = as_cuda_ndarray_variable(inp) inp = as_cuda_ndarray_variable(inp)
......
...@@ -72,19 +72,6 @@ def test_dnn_conv_desc_merge(): ...@@ -72,19 +72,6 @@ def test_dnn_conv_desc_merge():
assert d1 == d2 assert d1 == d2
def test_dnn_pool_desc_merge():
if not cuda.dnn.dnn_available():
raise SkipTest(cuda.dnn.dnn_available.msg)
x = theano.tensor.ftensor4('x')
y = dnn.dnn_pool(x, (2, 2))
z = dnn.dnn_pool(x, (2, 2))
f = theano.function([x], [y, z])
descs = [n for n in f.maker.fgraph.apply_nodes
if isinstance(n.op, dnn.GpuDnnPoolDesc)]
assert len(descs) == 1, f.maker.fgraph
def test_dnn_conv_merge(): def test_dnn_conv_merge():
"""This test that we merge correctly multiple dnn_conv. """This test that we merge correctly multiple dnn_conv.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论