提交 372bab54 authored 作者: f0k's avatar f0k

Fixed conv_fft graph optimizer to enforce original broadcast pattern

上级 fc548c69
...@@ -1222,7 +1222,12 @@ def _gpu_conv_to_fftconv(node): ...@@ -1222,7 +1222,12 @@ def _gpu_conv_to_fftconv(node):
(len(node.op.imshp) == 3) and (len(node.op.imshp) == 3) and
(node.op.imshp[0] is not None)): (node.op.imshp[0] is not None)):
kwargs['filter_shape'] = (node.op.nkern, node.op.imshp[0]) + node.op.kshp kwargs['filter_shape'] = (node.op.nkern, node.op.imshp[0]) + node.op.kshp
return conv2d_fft(node.inputs[0], node.inputs[1], **kwargs) rval = conv2d_fft(node.inputs[0], node.inputs[1], **kwargs)
if ('image_shape' in kwargs) or ('filter_shape' in kwargs):
# With given shape information, conv2d_fft may return a different
# broadcast pattern than GpuConv. This is forbidden, so we fix it.
rval = tensor.patternbroadcast(rval, node.outputs[0].type.broadcastable)
return rval
@local_optimizer([GpuConv]) @local_optimizer([GpuConv])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论