提交 a2bd7dcd authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix Reshape.infer_shape when one of the given shapes is -1

上级 6b16e20c
...@@ -3788,9 +3788,18 @@ class Reshape(Op): ...@@ -3788,9 +3788,18 @@ class Reshape(Op):
g_out, = grads g_out, = grads
return [reshape(g_out, shape(x), ndim=x.ndim), None] return [reshape(g_out, shape(x), ndim=x.ndim), None]
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
#we can't just put node.inputs[1] as not all op support interation # inputs[1] can contain at most one value of '-1', meaning the actual
#and this is needed in the ShapeOptimizer # shape of the output will be automatically computed by reshape, so
return [tuple([node.inputs[1][i] for i in range(self.ndim)])] # that the total number of elements stays the same.
# TODO: Maybe put that formula here?
# It's not trivial, because we would have to check if the product of
# all the non-minus-one shapes is a divisor of the product of the
# original shapes.
return [tuple([switch(eq(node.inputs[1][i], -1),
theano.tensor.opt.Shape_i(i)(node.outputs[0]),
node.inputs[1][i])
for i in range(self.ndim)]
)]
def reshape(x, newshape, ndim=None, name=None): def reshape(x, newshape, ndim=None, name=None):
if ndim is None: if ndim is None:
......
...@@ -3290,9 +3290,11 @@ class T_op_cache(unittest.TestCase): ...@@ -3290,9 +3290,11 @@ class T_op_cache(unittest.TestCase):
a = numpy.random.rand(5,2).astype(config.floatX) a = numpy.random.rand(5,2).astype(config.floatX)
self.assertTrue(numpy.all(fn_py(a) == fn_c_or_py(a))) self.assertTrue(numpy.all(fn_py(a) == fn_c_or_py(a)))
class T_reshape(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_reshape(): def test_reshape(self):
a = dvector() a = dvector()
b = dmatrix() b = dmatrix()
d = dmatrix() d = dmatrix()
...@@ -3361,9 +3363,30 @@ def test_reshape(): ...@@ -3361,9 +3363,30 @@ def test_reshape():
assert numpy.all(f(numpy.asarray([[0,1,2],[3,4,5]])) == numpy.asarray([[[0],[1],[2]],[[3],[4],[5]]])) assert numpy.all(f(numpy.asarray([[0,1,2],[3,4,5]])) == numpy.asarray([[[0],[1],[2]],[[3],[4],[5]]]))
assert f.maker.env.toposort()[-2].outputs[0].type.broadcastable==(False, False, True) assert f.maker.env.toposort()[-2].outputs[0].type.broadcastable==(False, False, True)
assert numpy.all(f_sub(a_val,b_val)==[2,3]) assert numpy.all(f_sub(a_val,b_val)==[2,3])
def test_infer_shape(self):
a = matrix('a')
shapes = ivector('shapes')
ndim = 2
r = a.reshape(shapes, ndim=2)
z = zeros_like(r)
f = function([a, shapes], z.shape)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.uniform(size=(3,4)).astype(config.floatX)
self.assertTrue((f(a_val, [4, 3]) == [4, 3]).all())
self.assertTrue((f(a_val, [-1, 3]) == [4, 3]).all())
self.assertTrue((f(a_val, [4, -1]) == [4, 3]).all())
self.assertRaises(ValueError, f, a_val, [-1, 5])
self.assertRaises(ValueError, f, a_val, [7, -1])
self.assertRaises(ValueError, f, a_val, [7, 5])
self.assertRaises(ValueError, f, a_val, [-1, -1])
def test_make_column_matrix_broadcastable(): def test_make_column_matrix_broadcastable():
# The goal of the operation made by `b` is to ensure the second dimension # The goal of the operation made by `b` is to ensure the second dimension
# of the column matrix is broadcastable. # of the column matrix is broadcastable.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论