提交 64c9ceeb authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki 提交者: Pascal Lamblin

Reshape to shape containing 1s is replaced with DimShuffle

上级 0d21995f
...@@ -48,7 +48,8 @@ from theano.gof.opt import (Optimizer, pre_constant_merge, ...@@ -48,7 +48,8 @@ from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer) pre_greedy_local_optimizer)
from theano.gof import toolbox from theano.gof import toolbox
from theano.tensor.basic import (Alloc, get_scalar_constant_value, ShapeError, from theano.tensor.basic import (Alloc, get_scalar_constant_value, ShapeError,
extract_constant, NotScalarConstantError) extract_constant, NotScalarConstantError,
Reshape)
from six import StringIO from six import StringIO
_logger = logging.getLogger('theano.tensor.opt') _logger = logging.getLogger('theano.tensor.opt')
def local_useless_reshape(node):
    """
    Simplify useless or needlessly general Reshape nodes.

    Two rewrites are performed:

    1. A Reshape from a 1-d vector to a 1-d vector with the same
       broadcastable pattern is removed entirely.
    2. A Reshape whose (constant) target shape contains size-1
       dimensions is replaced by a smaller Reshape to only the non-unit
       dimensions, followed by a DimShuffle that inserts the
       broadcastable axes.  A DimShuffle is cheaper than a Reshape and
       easier for later optimizations to reason about.

    Returns a one-element list with the replacement variable, or
    False/None when the rewrite does not apply.
    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    # `inp` rather than `input`: avoid shadowing the builtin.
    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    # Case 1: vector -> vector with an identical broadcastable pattern
    # is a no-op.
    if (inp.ndim == 1 and output.ndim == 1 and
            inp.broadcastable == output.broadcastable):
        return [inp]

    # Case 2 needs the target shape at compile time.  BUG FIX: the
    # previous code read `output_shape.value` unconditionally, which
    # raises AttributeError when the shape is symbolic instead of a
    # constant.
    if not isinstance(output_shape, T.TensorConstant):
        return False

    dimshuffle_new_order = []
    new_output_shape = []
    i = 0  # index over the output of the new, smaller reshape
    for dim in output_shape.data:
        if dim == 1:
            # Size-1 dims become broadcastable axes inserted by the
            # DimShuffle instead of being produced by the Reshape.
            dimshuffle_new_order.append('x')
        else:
            dimshuffle_new_order.append(i)
            new_output_shape.append(dim)
            i = i + 1

    # BUG FIX: only rewrite when at least one unit dimension was found.
    # The previous `len(dimshuffle_new_order) > 0` condition also fired
    # when there were no size-1 dims, replacing the Reshape with an
    # identical Reshape plus an identity DimShuffle — an equivalent
    # graph the optimizer could keep rewriting forever.
    if 'x' in dimshuffle_new_order:
        inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
        # NOTE(review): assumes DimShuffle is already imported at the
        # top of this module (theano.tensor.elemwise) — confirm.
        return [DimShuffle(inner.type.broadcastable,
                           dimshuffle_new_order)(inner)]
@register_canonicalize @register_canonicalize
......
...@@ -34,6 +34,7 @@ from theano.tensor.opt import ( ...@@ -34,6 +34,7 @@ from theano.tensor.opt import (
local_dimshuffle_lift, local_dimshuffle_lift,
local_useless_alloc, local_useless_alloc,
local_greedy_distributor, local_greedy_distributor,
local_useless_reshape,
mul_canonizer, mul_canonizer,
out2in, out2in,
Shape_i, Shape_i,
...@@ -6164,15 +6165,45 @@ class Test_Reshape(unittest.TestCase): ...@@ -6164,15 +6165,45 @@ class Test_Reshape(unittest.TestCase):
assert sum(isinstance(node.op, self.op) for node in topo) == 1 assert sum(isinstance(node.op, self.op) for node in topo) == 1
class Test_local_useless_reshape(unittest.TestCase):
    """Tests for the `local_useless_reshape` optimization."""

    def setUp(self):
        self.rng = numpy.random.RandomState(utt.fetch_seed())

    def test_0(self):
        # mgrid builds a Reshape whose output shape equals its input
        # shape; the optimization must remove it entirely.
        mode = theano.compile.get_default_mode().including(
            'local_useless_reshape')
        i = T.iscalar('i')
        m = theano.tensor.mgrid[0:i, ]
        f = theano.function([i], m, mode=mode)
        topo = f.maker.fgraph.toposort()
        assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)

    def test_1(self):
        # Reshapes whose target shape contains 1s must be rewritten as
        # a smaller Reshape followed by a broadcasting DimShuffle.
        reshape_lift = out2in(local_useless_reshape)
        x = shared(self.rng.randn(4,))
        y = shared(self.rng.randn(5, 6))
        reshape_x = tensor.reshape(x, (1, 4))
        reshape_y = tensor.reshape(y, (1, 5, 1, 6, 1, 1))

        g = FunctionGraph([x, y], [reshape_x, reshape_y])
        self.assertTrue(str(g) == ("[Reshape{2}"
                                   "(<TensorType(float64, vector)>, "
                                   "TensorConstant{[1 4]}), "
                                   "Reshape{6}"
                                   "(<TensorType(float64, matrix)>, "
                                   "TensorConstant{[1 5 1 6 1 1]})]"))

        reshape_lift.optimize(g)
        # BUG FIX: removed a leftover `import ipdb; ipdb.set_trace()`
        # breakpoint that would hang any automated test run.
        # BUG FIX: the second expected string was malformed — it was
        # missing the "(" before "Reshape{6}" and its matching ")", so
        # the assertion could never pass.
        self.assertTrue(str(g) == ("[DimShuffle{x,0}"
                                   "(Reshape{2}(<TensorType(float64, vector)>, "
                                   "TensorConstant{4})), "
                                   "DimShuffle{x,0,x,1,x,x}"
                                   "(Reshape{6}(<TensorType(float64, matrix)>, "
                                   "TensorConstant{[5 6]}))]"))

        # Check stacktrace was copied over correctly after opt was applied
        self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_local_reshape_lift(): def test_local_reshape_lift():
x = tensor.tensor4() x = tensor.tensor4()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论