提交 bb4e32e1 authored 作者: Frederic Bastien's avatar Frederic Bastien

In the reshape op, put the output broadcast flag correctly when we receive an…

In the reshape op, put the output broadcast flag correctly when we receive an int 1 as the new shape.
上级 e57bfe5f
......@@ -3592,6 +3592,7 @@ class Reshape(Op):
shp_list = [shp_orig]
for index in xrange(self.ndim):
y = shp_list[index]
y = as_tensor_variable(y)
# Try to see if we can infer that y has a constant value of 1.
# If so, that dimension should be broadcastable.
try:
......
......@@ -2808,14 +2808,14 @@ def test_reshape():
c = reshape(b, as_tensor_variable(6), ndim=1)
f = inplace_func([b], c)
assert numpy.all(f(numpy.asarray([[0,1,2],[3,4,5]])) == numpy.asarray([0,1,2,3,4,5]))
print f.maker.env.toposort()
#print f.maker.env.toposort()
#check that we remove the useless reshape
#basic to 1 dim(with list)
c = reshape(b, (as_tensor_variable(6),), ndim=1)
f = inplace_func([b], c)
assert numpy.all(f(numpy.asarray([[0,1,2],[3,4,5]])) == numpy.asarray([0,1,2,3,4,5]))
print f.maker.env.toposort()
#print f.maker.env.toposort()
#check that we remove the useless reshape
#basic to shape object of same ndim
......@@ -2860,6 +2860,13 @@ def test_reshape():
#assert numpy.all(f_sub(a_val,numpy.asarray([[0,1],[2,3],[4,5]]))==[2,3])#work in FAST_RUN, but fail on other!
#assert numpy.all(f_sub(a_val,numpy.asarray([[0,1],[2,3],[4,5],[6,7]]))==[2,3])#work in FAST_RUN, but fail on other!
# test broadcast flag for constant value of 1
c = reshape(b, (b.shape[0],b.shape[1],1))
f = inplace_func([b], c)
assert numpy.all(f(numpy.asarray([[0,1,2],[3,4,5]])) == numpy.asarray([[[0],[1],[2]],[[3],[4],[5]]]))
assert f.maker.env.toposort()[-2].outputs[0].type.broadcastable==(False, False, True)
assert numpy.all(f_sub(a_val,b_val)==[2,3])
def test_make_column_matrix_broadcastable():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论