提交 4d659548 authored 作者: --global's avatar --global

Fix broadcastable pattern of flatten output

上级 c356183f
......@@ -4190,8 +4190,18 @@ class Flatten(Op):
if self.outdim < 1 or (x.ndim and self.outdim > x.ndim):
raise ValueError('invalid output ndimensions (%i) for tensor of '
'rank %i' % (self.outdim, t_x.ndim))
# Infer the broadcastable pattern of the output. For every dimension
# unaffected by the flatten, the broadcast flag should be unchanged.
# For the dimension resulting from the collapse of other dimensions,
# it should be broadcastable iff all the collapsed dimensions were
# broadcastable.
bcast_kept_dims = x.broadcastable[:self.outdim - 1]
bcast_new_dim = python_all(x.broadcastable[self.outdim - 1:])
broadcastable = bcast_kept_dims + (bcast_new_dim,)
return gof.Apply(self, [t_x], [tensor(x.type.dtype,
(False,) * self.outdim)])
broadcastable)])
def perform(self, node, inp, out_):
x, = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论