提交 f8a4ee58,作者:Pascal Lamblin

Merge pull request #3165 from carriepl/scan_pushout_opt_err

Fix broadcastable pattern of flatten output
...@@ -4190,8 +4190,18 @@ class Flatten(Op): ...@@ -4190,8 +4190,18 @@ class Flatten(Op):
if self.outdim < 1 or (x.ndim and self.outdim > x.ndim): if self.outdim < 1 or (x.ndim and self.outdim > x.ndim):
raise ValueError('invalid output ndimensions (%i) for tensor of ' raise ValueError('invalid output ndimensions (%i) for tensor of '
'rank %i' % (self.outdim, t_x.ndim)) 'rank %i' % (self.outdim, t_x.ndim))
# Infer the broadcastable pattern of the output. For every dimension
# unaffected by the flatten, the broadcast flag should be unchanged.
# For the dimension resulting from the collapse of other dimensions,
# it should be broadcastable iff all the collapsed dimensions were
# broadcastable.
bcast_kept_dims = x.broadcastable[:self.outdim - 1]
bcast_new_dim = python_all(x.broadcastable[self.outdim - 1:])
broadcastable = bcast_kept_dims + (bcast_new_dim,)
return gof.Apply(self, [t_x], [tensor(x.type.dtype, return gof.Apply(self, [t_x], [tensor(x.type.dtype,
(False,) * self.outdim)]) broadcastable)])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
x, = inp x, = inp
......
...@@ -1089,7 +1089,7 @@ _good_broadcast_unary_normal_float_no_empty_no_complex = copymod( ...@@ -1089,7 +1089,7 @@ _good_broadcast_unary_normal_float_no_empty_no_complex = copymod(
_good_broadcast_unary_normal_float_no_complex = copymod( _good_broadcast_unary_normal_float_no_complex = copymod(
_good_broadcast_unary_normal_float, _good_broadcast_unary_normal_float,
without=['complex']) without=['complex'])
_good_broadcast_unary_normal_float_no_complex_small_neg_range = dict( _good_broadcast_unary_normal_float_no_complex_small_neg_range = dict(
normal=[rand_ranged(-2, 5, (2, 3))], normal=[rand_ranged(-2, 5, (2, 3))],
corner_case=[corner_case], corner_case=[corner_case],
...@@ -1123,7 +1123,7 @@ _grad_broadcast_unary_normal = dict( ...@@ -1123,7 +1123,7 @@ _grad_broadcast_unary_normal = dict(
corner_case=[corner_case_grad], corner_case=[corner_case_grad],
# empty = [numpy.asarray([])] # XXX: should this be included? # empty = [numpy.asarray([])] # XXX: should this be included?
) )
_grad_broadcast_unary_normal_small_neg_range = dict( _grad_broadcast_unary_normal_small_neg_range = dict(
normal=[numpy.asarray(rand_ranged(-2, 5, (2, 3)), dtype=floatX)], normal=[numpy.asarray(rand_ranged(-2, 5, (2, 3)), dtype=floatX)],
corner_case=[corner_case_grad]) corner_case=[corner_case_grad])
...@@ -5100,6 +5100,31 @@ def test_flatten_outdim2_of_3(): ...@@ -5100,6 +5100,31 @@ def test_flatten_outdim2_of_3():
utt.verify_grad(Flatten(2), [a_val]) utt.verify_grad(Flatten(2), [a_val])
def test_flatten_broadcastable():
    # The output broadcastable pattern must stay consistent with the input's:
    # leading dimensions untouched by the flatten keep their flag, and the
    # trailing collapsed dimension is broadcastable only when every dimension
    # merged into it was broadcastable.
    #
    # Each case is (input broadcastable pattern, outdim, expected pattern).
    cases = (
        ((False, False, False, False), 2, (False, False)),
        ((False, False, False, True), 2, (False, False)),
        ((False, True, False, True), 2, (False, False)),
        ((False, True, True, True), 2, (False, True)),
        ((True, False, True, True), 3, (True, False, True)),
    )
    for in_pattern, n_out_dims, expected in cases:
        symbolic_input = TensorType('float64', in_pattern)()
        flattened = flatten(symbolic_input, outdim=n_out_dims)
        assert flattened.broadcastable == expected
def test_flatten_outdim_invalid(): def test_flatten_outdim_invalid():
a = dmatrix() a = dmatrix()
try: try:
...@@ -5500,7 +5525,7 @@ class TestNdGrid(unittest.TestCase): ...@@ -5500,7 +5525,7 @@ class TestNdGrid(unittest.TestCase):
nmgrid = (numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.], nmgrid = (numpy.mgrid[0:1:.1, 1:10:1., 10:100:10.],
numpy.mgrid[0:2:1, 1:10:1, 10:100:10]) numpy.mgrid[0:2:1, 1:10:1, 10:100:10])
tmgrid = (mgrid[0:1:.1, 1:10:1., 10:100:10.], tmgrid = (mgrid[0:1:.1, 1:10:1., 10:100:10.],
mgrid[0:2:1, 1:10:1, 10:100:10]) mgrid[0:2:1, 1:10:1, 10:100:10])
for n, t in zip(nmgrid, tmgrid): for n, t in zip(nmgrid, tmgrid):
for ng, tg in zip(n, t): for ng, tg in zip(n, t):
utt.assert_allclose(ng, tg.eval()) utt.assert_allclose(ng, tg.eval())
...@@ -5508,7 +5533,7 @@ class TestNdGrid(unittest.TestCase): ...@@ -5508,7 +5533,7 @@ class TestNdGrid(unittest.TestCase):
def test_ogrid_numpy_equiv(self): def test_ogrid_numpy_equiv(self):
nogrid = (numpy.ogrid[0:1:.1, 1:10:1., 10:100:10.], nogrid = (numpy.ogrid[0:1:.1, 1:10:1., 10:100:10.],
numpy.ogrid[0:2:1, 1:10:1, 10:100:10]) numpy.ogrid[0:2:1, 1:10:1, 10:100:10])
togrid = (ogrid[0:1:.1, 1:10:1., 10:100:10.], togrid = (ogrid[0:1:.1, 1:10:1., 10:100:10.],
ogrid[0:2:1, 1:10:1, 10:100:10]) ogrid[0:2:1, 1:10:1, 10:100:10])
for n, t in zip(nogrid, togrid): for n, t in zip(nogrid, togrid):
for ng, tg in zip(n, t): for ng, tg in zip(n, t):
...@@ -5539,8 +5564,8 @@ class TestNdGrid(unittest.TestCase): ...@@ -5539,8 +5564,8 @@ class TestNdGrid(unittest.TestCase):
for n, t in zip((nfogrid,niogrid), (ff(0, 10, 10.),fi(0, 10, 10))): for n, t in zip((nfogrid,niogrid), (ff(0, 10, 10.),fi(0, 10, 10))):
for ng, tg in zip(n, t): for ng, tg in zip(n, t):
utt.assert_allclose(ng, tg) utt.assert_allclose(ng, tg)
class TestInversePermutation(unittest.TestCase): class TestInversePermutation(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
...@@ -7699,7 +7724,7 @@ def test_allocempty(): ...@@ -7699,7 +7724,7 @@ def test_allocempty():
f = theano.function([], AllocEmpty("float32")(2, 3)) f = theano.function([], AllocEmpty("float32")(2, 3))
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
out = f() out = f()
assert out.shape == (2, 3) assert out.shape == (2, 3)
assert out.dtype == 'float32' assert out.dtype == 'float32'
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论