提交 04a94751 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix last problems with the merge opts and test.

上级 024ec750
...@@ -1515,7 +1515,7 @@ if True: ...@@ -1515,7 +1515,7 @@ if True:
def local_dnn_convi_alpha_merge(node, *inputs): def local_dnn_convi_alpha_merge(node, *inputs):
if version() == -1: if version() == -1:
return None return None
return [GpuDnnConvGradW()(*inputs)] return [GpuDnnConvGradI()(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
@output_merge(GpuDnnConv, alpha_in=4, out_in=2, nd=4) @output_merge(GpuDnnConv, alpha_in=4, out_in=2, nd=4)
......
...@@ -459,11 +459,11 @@ def test_dnn_conv_merge(): ...@@ -459,11 +459,11 @@ def test_dnn_conv_merge():
ir = img - lr * gi ir = img - lr * gi
f1 = theano.function([img, kern, out], [fr, wr, ir], mode=mode_with_gpu) f1 = theano.function([img, kern, out], [fr, wr, ir], mode=mode_with_gpu)
assert isinstance(f1.maker.fgraph.outputs[0].owner.op, assert isinstance(f1.maker.fgraph.outputs[0].owner.inputs[0].owner.op,
dnn.GpuDnnConv) dnn.GpuDnnConv)
assert isinstance(f1.maker.fgraph.outputs[0].owner.op, assert isinstance(f1.maker.fgraph.outputs[1].owner.inputs[0].owner.op,
dnn.GpuDnnConvGradW) dnn.GpuDnnConvGradW)
assert isinstance(f1.maker.fgraph.outputs[0].owner.op, assert isinstance(f1.maker.fgraph.outputs[2].owner.inputs[0].owner.op,
dnn.GpuDnnConvGradI) dnn.GpuDnnConvGradI)
mode = mode_with_gpu mode = mode_with_gpu
...@@ -476,11 +476,11 @@ def test_dnn_conv_merge(): ...@@ -476,11 +476,11 @@ def test_dnn_conv_merge():
f2 = theano.function([img, kern, out], [fr, wr, ir], mode=mode) f2 = theano.function([img, kern, out], [fr, wr, ir], mode=mode)
assert not isinstance(f1.maker.fgraph.outputs[0].owner.op, assert not isinstance(f2.maker.fgraph.outputs[0].owner.inputs[0].owner.op,
dnn.GpuDnnConv) dnn.GpuDnnConv)
assert not isinstance(f1.maker.fgraph.outputs[0].owner.op, assert not isinstance(f2.maker.fgraph.outputs[1].owner.inputs[0].owner.op,
dnn.GpuDnnConvGradW) dnn.GpuDnnConvGradW)
assert not isinstance(f1.maker.fgraph.outputs[0].owner.op, assert not isinstance(f2.maker.fgraph.outputs[2].owner.inputs[0].owner.op,
dnn.GpuDnnConvGradI) dnn.GpuDnnConvGradI)
out_f1 = f1(img_val, kern_val, out_val) out_f1 = f1(img_val, kern_val, out_val)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论