提交 feba3ecb authored 作者: affanv14's avatar affanv14

add shape info for conv3d2d

上级 527b324e
......@@ -1754,10 +1754,13 @@ def local_abstractconv3d2d(node):
filter_dilation = node.op.filter_dilation
if subsample == (1, 1, 1) and filter_dilation == (1, 1, 1):
rval = conv3d2d.conv3d(gpu_contiguous(img.dimshuffle(0, 2, 1, 3, 4)),
gpu_contiguous(kern.dimshuffle(0, 2, 1, 3, 4)),
reorder_array = [0, 2, 1, 3, 4]
rval = conv3d2d.conv3d(gpu_contiguous(img.dimshuffle(*reorder_array)),
gpu_contiguous(kern.dimshuffle(*reorder_array)),
[node.op.imshp[i] for i in reorder_array],
[node.op.kshp[i] for i in reorder_array],
border_mode=border_mode)
rval = as_gpuarray_variable(rval.dimshuffle(0, 2, 1, 3, 4),
rval = as_gpuarray_variable(rval.dimshuffle(*reorder_array),
context_name=ctx)
return [rval]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论