提交 350f5b6b authored 作者: Frederic Bastien's avatar Frederic Bastien

fix bug(it disabled some fonctionality) introduced in commit 54f8821476cc. Add…

fix bug(it disabled some fonctionality) introduced in commit 54f8821476cc. Add test to catch it directly.
上级 4b1e64ca
...@@ -3450,8 +3450,11 @@ class Reshape(Op): ...@@ -3450,8 +3450,11 @@ class Reshape(Op):
return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcast)]) return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcast)])
else: else:
bcasts = [False] * self.ndim bcasts = [False] * self.ndim
shp_list = shp_orig
if not isinstance(shp_orig,(list,tuple)):
shp_list = [shp_orig]
for index in xrange(self.ndim): for index in xrange(self.ndim):
y = shp_orig[index] y = shp_list[index]
# Try to see if we can infer that y has a constant value of 1. # Try to see if we can infer that y has a constant value of 1.
# If so, that dimension should be broadcastable. # If so, that dimension should be broadcastable.
try: try:
......
...@@ -2531,9 +2531,15 @@ def test_reshape(): ...@@ -2531,9 +2531,15 @@ def test_reshape():
a = dvector() a = dvector()
b = dmatrix() b = dmatrix()
#basic to 1 dim
c = reshape(b, as_tensor_variable(6), ndim=1)
f = inplace_func([b], c)
assert numpy.all(f(numpy.asarray([[0,1,2],[3,4,5]])) == numpy.asarray([0,1,2,3,4,5]))
print f.maker.env.toposort()
#check that we remove the useless reshape
#basic to 2 dims
c = reshape(a, [2,3]) c = reshape(a, [2,3])
#basic
f = inplace_func([a], c) f = inplace_func([a], c)
assert numpy.all(f(numpy.asarray([0,1,2,3,4,5])) == numpy.asarray([[0,1,2], [3,4,5]])) assert numpy.all(f(numpy.asarray([0,1,2,3,4,5])) == numpy.asarray([[0,1,2], [3,4,5]]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论