提交 f8a4ee58 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3165 from carriepl/scan_pushout_opt_err

Fix broadcastable pattern of flatten output
...@@ -4190,8 +4190,18 @@ class Flatten(Op): ...@@ -4190,8 +4190,18 @@ class Flatten(Op):
if self.outdim < 1 or (x.ndim and self.outdim > x.ndim): if self.outdim < 1 or (x.ndim and self.outdim > x.ndim):
raise ValueError('invalid output ndimensions (%i) for tensor of ' raise ValueError('invalid output ndimensions (%i) for tensor of '
'rank %i' % (self.outdim, t_x.ndim)) 'rank %i' % (self.outdim, t_x.ndim))
# Infer the broadcastable pattern of the output. For every dimension
# unaffected by the flatten, the broadcast flag should be unchanged.
# For the dimension resulting from the collapse of other dimensions,
# it should be broadcastable iff all the collapsed dimensions were
# broadcastable.
bcast_kept_dims = x.broadcastable[:self.outdim - 1]
bcast_new_dim = python_all(x.broadcastable[self.outdim - 1:])
broadcastable = bcast_kept_dims + (bcast_new_dim,)
return gof.Apply(self, [t_x], [tensor(x.type.dtype, return gof.Apply(self, [t_x], [tensor(x.type.dtype,
(False,) * self.outdim)]) broadcastable)])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
x, = inp x, = inp
......
...@@ -5100,6 +5100,31 @@ def test_flatten_outdim2_of_3(): ...@@ -5100,6 +5100,31 @@ def test_flatten_outdim2_of_3():
utt.verify_grad(Flatten(2), [a_val]) utt.verify_grad(Flatten(2), [a_val])
def test_flatten_broadcastable():
# Ensure that the broadcastable pattern of the output is coherent with
# that of the input
inp = TensorType('float64', (False, False, False, False))()
out = flatten(inp, outdim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, False, False, True))()
out = flatten(inp, outdim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, True, False, True))()
out = flatten(inp, outdim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, True, True, True))()
out = flatten(inp, outdim=2)
assert out.broadcastable == (False, True)
inp = TensorType('float64', (True, False, True, True))()
out = flatten(inp, outdim=3)
assert out.broadcastable == (True, False, True)
def test_flatten_outdim_invalid(): def test_flatten_outdim_invalid():
a = dmatrix() a = dmatrix()
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论