提交 fc1d81a1 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix for corrmm optimization.

I'm not completely sure about that one.
上级 e0b1fc7c
...@@ -496,11 +496,13 @@ def local_conv2d_gradweight_corrmm(node): ...@@ -496,11 +496,13 @@ def local_conv2d_gradweight_corrmm(node):
if not isinstance(img.type, CudaNdarrayType) or \ if not isinstance(img.type, CudaNdarrayType) or \
not isinstance(topgrad.type, CudaNdarrayType): not isinstance(topgrad.type, CudaNdarrayType):
return None return None
if node.op.filters_flip:
img = img[:, :, ::-1, ::-1]
rval = GpuCorrMM_gradWeights(border_mode=node.op.border_mode, rval = GpuCorrMM_gradWeights(border_mode=node.op.border_mode,
subsample=node.op.subsample)( subsample=node.op.subsample)(
gpu_contiguous(img), gpu_contiguous(topgrad), shape) gpu_contiguous(img), gpu_contiguous(topgrad), shape)
if node.op.filters_flip:
rval = rval[:, :, ::-1, ::-1]
rval = as_cuda_ndarray_variable(rval)
#rval = patternbroadcast(rval, node.outputs[0].broadcastable)
return [rval] return [rval]
register_specialize_device(local_conv2d_gradweight_corrmm, 'conv_gemm') register_specialize_device(local_conv2d_gradweight_corrmm, 'conv_gemm')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论