提交 d4126ac2 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki 提交者: Pascal Lamblin

looping over dimensions

上级 763a048a
......@@ -4184,16 +4184,16 @@ def local_useless_reshape(node):
dimshuffle_new_order = []
new_output_shape = []
i = 0 # index over the output of the new reshape
import ipdb; ipdb.set_trace()
for dim in extract_constant(output_shape, only_process_constants=True):
index = 0 # index over the output of the new reshape
for i in xrange(output.ndim):
dim = extract_constant(output_shape[i], only_process_constants=False)
if dim == 1:
dimshuffle_new_order.append('x')
else:
dimshuffle_new_order.append(i)
dimshuffle_new_order.append(index)
new_output_shape.append(dim)
i = i + 1
if i != output.ndim:
index = index + 1
if index != output.ndim:
inner = op.__class__(len(new_output_shape))(input, new_output_shape)
return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
......
......@@ -6194,17 +6194,16 @@ class Test_local_useless_reshape(unittest.TestCase):
"TensorConstant{[1 5 1 6 1 1]})]"))
reshape_lift.optimize(g)
import ipdb; ipdb.set_trace()
self.assertTrue(str(g) == "[DimShuffle{x,0}"
"(Reshape{2}(<TensorType(float64, vector)>, "
"TensorConstant{4})), "
"(<TensorType(float64, vector)>), "
"DimShuffle{x,0,x,1,x,x}"
"Reshape{6}(<TensorType(float64, matrix)>, "
"TensorConstant{[5 6]})]")
"(Reshape{2}(<TensorType(float64, matrix)>, "
"TensorConstant{[5 6]}))]")
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_local_reshape_lift():
x = tensor.tensor4()
out = T.exp(x).reshape([x.size])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论