提交 4d82d018 authored 作者: James Bergstra's avatar James Bergstra

ConvOp - corrected broadcastable flags in output

上级 8d761b6c
......@@ -382,7 +382,7 @@ class ConvOp(Op):
self.unroll_kern=new
if all_shape:
self.outshp = ConvOp.getOutputShape(self.imshp_logical[1:], self.kshp_logical, (dx,dy), output_mode)
self.outshp = ConvOp.getOutputShape(self.imshp_logical[1:], self.kshp_logical, (dx,dy), output_mode)
self.fulloutshp = ConvOp.getOutputShape(self.imshp_logical[1:], self.kshp_logical, (1,1), output_mode)
else:
self.outshp = None
......@@ -478,9 +478,13 @@ class ConvOp(Op):
if _inputs.type.dtype != _kerns.type.dtype:
raise NotImplementedError("The image and the kernel must have the same type."
"inputs(%s), kerns(%s)"%(_inputs.dtype, _kerns.dtype))
if self.outshp is not None:
bcastable23 = [self.outshp[0]==1, self.outshp[1]==1]
else:
bcastable23 = [False, False]
output = tensor.tensor(dtype=_inputs.type.dtype,
broadcastable=[_inputs.broadcastable[0],
_kerns.broadcastable[0], False, False]);
_kerns.broadcastable[0]]+bcastable23);
return Apply(self, [_inputs, _kerns], [output])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论