提交 d4126ac2 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki 提交者: Pascal Lamblin

looping over dimensions

上级 763a048a
...@@ -4184,16 +4184,16 @@ def local_useless_reshape(node): ...@@ -4184,16 +4184,16 @@ def local_useless_reshape(node):
dimshuffle_new_order = [] dimshuffle_new_order = []
new_output_shape = [] new_output_shape = []
i = 0 # index over the output of the new reshape index = 0 # index over the output of the new reshape
import ipdb; ipdb.set_trace() for i in xrange(output.ndim):
for dim in extract_constant(output_shape, only_process_constants=True): dim = extract_constant(output_shape[i], only_process_constants=False)
if dim == 1: if dim == 1:
dimshuffle_new_order.append('x') dimshuffle_new_order.append('x')
else: else:
dimshuffle_new_order.append(i) dimshuffle_new_order.append(index)
new_output_shape.append(dim) new_output_shape.append(dim)
i = i + 1 index = index + 1
if i != output.ndim: if index != output.ndim:
inner = op.__class__(len(new_output_shape))(input, new_output_shape) inner = op.__class__(len(new_output_shape))(input, new_output_shape)
return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
......
...@@ -6194,17 +6194,16 @@ class Test_local_useless_reshape(unittest.TestCase): ...@@ -6194,17 +6194,16 @@ class Test_local_useless_reshape(unittest.TestCase):
"TensorConstant{[1 5 1 6 1 1]})]")) "TensorConstant{[1 5 1 6 1 1]})]"))
reshape_lift.optimize(g) reshape_lift.optimize(g)
import ipdb; ipdb.set_trace()
self.assertTrue(str(g) == "[DimShuffle{x,0}" self.assertTrue(str(g) == "[DimShuffle{x,0}"
"(Reshape{2}(<TensorType(float64, vector)>, " "(<TensorType(float64, vector)>), "
"TensorConstant{4})), "
"DimShuffle{x,0,x,1,x,x}" "DimShuffle{x,0,x,1,x,x}"
"Reshape{6}(<TensorType(float64, matrix)>, " "(Reshape{2}(<TensorType(float64, matrix)>, "
"TensorConstant{[5 6]})]") "TensorConstant{[5 6]}))]")
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace')) self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_local_reshape_lift(): def test_local_reshape_lift():
x = tensor.tensor4() x = tensor.tensor4()
out = T.exp(x).reshape([x.size]) out = T.exp(x).reshape([x.size])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论