提交 fc548c69 authored 作者: f0k's avatar f0k

Added required as_cuda_ndarray_variable in conv_gemm optimizer

上级 2c3bbdcf
......@@ -1391,10 +1391,14 @@ def local_conv_gemm(node):
prod2 *= node.op.imshp[0]
# compare to decide
if prod1 > prod2:
return [GpuCorrMM_gradWeights('valid', subsample, pad)(
# (we need to wrap the result in as_cuda_ndarray_variable,
# because we are not allowed to replace a CudaNdarray with
# a DimShuffle instance in a graph optimization)
return [theano.sandbox.cuda.as_cuda_ndarray_variable(
GpuCorrMM_gradWeights('valid', subsample, pad)(
gpu_contiguous(img.dimshuffle(1, 0, 2, 3)),
gpu_contiguous(kern.dimshuffle(1, 0, 2, 3))
).dimshuffle(1, 0, 2, 3)]
).dimshuffle(1, 0, 2, 3))]
# use GpuCorrMM if we did not choose GpuCorrMM_gradWeights above
return [GpuCorrMM('valid', subsample, pad)(
gpu_contiguous(img), gpu_contiguous(kern))]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论