提交 79262793 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix broadcastable pattern in abstract gradW

上级 fc1d81a1
...@@ -245,8 +245,8 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d): ...@@ -245,8 +245,8 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
' or border_mode == "half"') ' or border_mode == "half"')
shape = as_tensor_variable(shape) shape = as_tensor_variable(shape)
broadcastable=[topgrad.broadcastable[0], broadcastable=[topgrad.broadcastable[1],
img.broadcastable[0], img.broadcastable[1],
False, False] False, False]
output = img.type.clone(broadcastable=broadcastable)() output = img.type.clone(broadcastable=broadcastable)()
return Apply(self, [img, topgrad, shape], [output]) return Apply(self, [img, topgrad, shape], [output])
...@@ -501,8 +501,8 @@ def local_conv2d_gradweight_corrmm(node): ...@@ -501,8 +501,8 @@ def local_conv2d_gradweight_corrmm(node):
gpu_contiguous(img), gpu_contiguous(topgrad), shape) gpu_contiguous(img), gpu_contiguous(topgrad), shape)
if node.op.filters_flip: if node.op.filters_flip:
rval = rval[:, :, ::-1, ::-1] rval = rval[:, :, ::-1, ::-1]
rval = as_cuda_ndarray_variable(rval) rval = patternbroadcast(rval, node.outputs[0].broadcastable)
#rval = patternbroadcast(rval, node.outputs[0].broadcastable) rval = as_cuda_ndarray_variable(rval)
return [rval] return [rval]
register_specialize_device(local_conv2d_gradweight_corrmm, 'conv_gemm') register_specialize_device(local_conv2d_gradweight_corrmm, 'conv_gemm')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论