提交 f08b1cd3 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add test for multiple clients of convolution.

上级 bcbcc1b3
......@@ -659,6 +659,33 @@ def test_dnn_conv_alpha_output_merge():
utt.assert_allclose(v1, v2)
def test_dnn_conv_merge_mouts():
# make sure it doesn't attempt to output/alpha merge a convolution
# that has multiple clients.
if not cuda.dnn.dnn_available():
raise SkipTest(cuda.dnn.dnn_available.msg)
img = T.ftensor4()
kern = T.ftensor4()
out = T.ftensor4()
conv = dnn.dnn_conv(img, kern)
lr = numpy.asarray(0.05, dtype='float32')
if cuda.dnn.version() == -1:
# Can't merge alpha with cudnn v1
fr = conv + out
else:
fr = lr * (conv + out)
rr = conv * lr
f = theano.function([img, kern, out], [fr, rr], mode=mode_with_gpu)
assert not isinstance(f.maker.fgraph.outputs[0].owner.inputs[0].owner.op,
dnn.GpuDnnConv)
assert not isinstance(f.maker.fgraph.outputs[1].owner.inputs[0].owner.op,
dnn.GpuDnnConv)
def test_dnn_conv_grad():
if not cuda.dnn.dnn_available() or dnn.version() == -1:
raise SkipTest('alpha != 1.0 not supported in cudnn v1')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论