提交 79f24d78 authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

fix broadcast error

上级 daf47e46
...@@ -2114,6 +2114,10 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2114,6 +2114,10 @@ class AbstractConv_gradInputs(BaseAbstractConv):
'filters does not match given kshp.') 'filters does not match given kshp.')
shape = as_tensor_variable(shape) shape = as_tensor_variable(shape)
if self.num_groups > 1:
broadcastable = [topgrad.type.broadcastable[0],
False] + ([False] * self.convdim)
else:
broadcastable = [topgrad.type.broadcastable[0], broadcastable = [topgrad.type.broadcastable[0],
kern.type.broadcastable[1]] + ([False] * self.convdim) kern.type.broadcastable[1]] + ([False] * self.convdim)
output = kern.type.clone(broadcastable=broadcastable)() output = kern.type.clone(broadcastable=broadcastable)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论