提交 79f24d78 authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

fix broadcast error

上级 daf47e46
...@@ -2114,8 +2114,12 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2114,8 +2114,12 @@ class AbstractConv_gradInputs(BaseAbstractConv):
'filters does not match given kshp.') 'filters does not match given kshp.')
shape = as_tensor_variable(shape) shape = as_tensor_variable(shape)
broadcastable = [topgrad.type.broadcastable[0], if self.num_groups > 1:
kern.type.broadcastable[1]] + ([False] * self.convdim) broadcastable = [topgrad.type.broadcastable[0],
False] + ([False] * self.convdim)
else:
broadcastable = [topgrad.type.broadcastable[0],
kern.type.broadcastable[1]] + ([False] * self.convdim)
output = kern.type.clone(broadcastable=broadcastable)() output = kern.type.clone(broadcastable=broadcastable)()
return Apply(self, [kern, topgrad, shape], [output]) return Apply(self, [kern, topgrad, shape], [output])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论