提交 70e23f42 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

move useless_dimshuffle_in_reshape into its own opt

and out of dimshuffle_lifter
上级 8db5dc89
......@@ -559,7 +559,7 @@ def is_dimshuffle_useless(new_order, input):
return is_useless
@gof.local_optimizer([DimShuffle, Reshape])
@gof.local_optimizer([DimShuffle])
def local_dimshuffle_lift(node):
"""
"Lifts" DimShuffle through Elemwise operations and merges
......@@ -573,33 +573,8 @@ def local_dimshuffle_lift(node):
After this transform, clusters of Elemwise operations are
void of DimShuffle operations.
Also removes useless DimShuffle operation inside Reshape:
reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
reshape(col.dimshuffle(0), shp) => reshape(col, shp)
"""
op = node.op
if (isinstance(op, Reshape) and
node.inputs[0].owner is not None and
isinstance(node.inputs[0].owner.op, DimShuffle)):
new_order = node.inputs[0].owner.op.new_order
input = node.inputs[0].owner.inputs[0]
broadcastables = node.inputs[0].broadcastable
new_order_of_nonbroadcastables = []
for i, bd in zip(new_order, broadcastables):
if not bd:
new_order_of_nonbroadcastables.append(i)
no_change_in_order = all(
new_order_of_nonbroadcastables[i] <= new_order_of_nonbroadcastables[i + 1]
for i in xrange(len(new_order_of_nonbroadcastables) - 1))
if no_change_in_order:
shape = node.inputs[1]
ret = op.__class__(node.outputs[0].ndim)(input, shape)
copy_stack_trace(node.outputs[0], ret)
return [ret]
if not isinstance(op, DimShuffle):
return False
......@@ -633,6 +608,42 @@ def local_dimshuffle_lift(node):
return [ret]
@register_canonicalize
@gof.local_optimizer([Reshape])
def local_useless_dimshuffle_in_reshape(node):
    """
    Remove a DimShuffle that feeds a Reshape when it cannot change the result.

    The DimShuffle is bypassed when the entries of its ``new_order`` that
    land on non-broadcastable dimensions of its output are in non-decreasing
    order, for example:

        reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
        reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
        reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
        reshape(col.dimshuffle(0), shp) => reshape(col, shp)

    Parameters
    ----------
    node
        The apply node under consideration.

    Returns
    -------
    list or False
        A one-element list holding the replacement Reshape output when the
        rewrite applies, ``False`` otherwise.

    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    dimshuffle_node = node.inputs[0].owner
    if (dimshuffle_node is None or
            not isinstance(dimshuffle_node.op, DimShuffle)):
        return False

    new_order = dimshuffle_node.op.new_order
    # Named ``shuffled_input`` rather than ``input`` to avoid shadowing the
    # builtin.
    shuffled_input = dimshuffle_node.inputs[0]
    # Broadcast pattern of the DimShuffle *output* (the Reshape's data input).
    broadcastables = node.inputs[0].broadcastable

    # Keep only the new_order entries sitting on non-broadcastable output
    # dimensions; entries on broadcastable positions are ignored.
    kept_axes = [axis for axis, is_bcast in zip(new_order, broadcastables)
                 if not is_bcast]
    no_change_in_order = all(
        before <= after for before, after in zip(kept_axes, kept_axes[1:]))
    if not no_change_in_order:
        # Explicit False, consistent with the other "no rewrite" exits.
        return False

    shape = node.inputs[1]
    ret = op.__class__(node.outputs[0].ndim)(shuffled_input, shape)
    copy_stack_trace(node.outputs[0], ret)
    return [ret]
@register_canonicalize
@gof.local_optimizer([T.DimShuffle])
def local_lift_transpose_through_dot(node):
......
......@@ -12,7 +12,7 @@ import unittest
import numpy
from six.moves import xrange
from nose.plugins.skip import SkipTest
from nose.tools import assert_raises
from nose.tools import assert_raises, assert_true
from numpy.testing import dec
from numpy.testing.noseclasses import KnownFailureTest
......@@ -32,6 +32,7 @@ import theano.tensor.opt as opt
from theano.tensor.opt import (
local_add_specialize,
local_dimshuffle_lift,
local_useless_dimshuffle_in_reshape,
local_useless_alloc,
local_greedy_distributor,
local_useless_reshape,
......@@ -223,32 +224,34 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_useless_dimshuffle_in_presence_of_reshape(self):
    """Reshape(DimShuffle(x), shp) graphs lose the DimShuffle after the opt."""
    float64_type = lambda pattern: TensorType(broadcastable=pattern,
                                              dtype='float64')
    vector = float64_type((False,))('vector')
    mat = float64_type((False, False))('mat')
    row = float64_type((True, False))('row')
    col = float64_type((False, True))('col')

    # One Reshape-over-DimShuffle output per input kind.
    outputs = [
        tensor.reshape(vector.dimshuffle('x', 0), vector.shape),
        tensor.reshape(mat.dimshuffle('x', 0, 'x', 1), mat.shape),
        tensor.reshape(row.dimshuffle(1, 'x'), row.shape),
        tensor.reshape(col.dimshuffle(0), col.shape),
    ]
    g = FunctionGraph([vector, mat, row, col], outputs)

    graph_before = ("[Reshape{1}(DimShuffle{x,0}(vector), Shape(vector)), "
                    "Reshape{2}(DimShuffle{x,0,x,1}(mat), Shape(mat)), "
                    "Reshape{2}(DimShuffle{1,x}(row), Shape(row)), "
                    "Reshape{2}(DimShuffle{0}(col), Shape(col))]")
    self.assertTrue(str(g) == graph_before)

    dimshuffle_lift.optimize(g)

    graph_after = ("[Reshape{1}(vector, Shape(vector)), "
                   "Reshape{2}(mat, Shape(mat)), "
                   "Reshape{2}(row, Shape(row)), "
                   "Reshape{2}(col, Shape(col))]")
    self.assertTrue(str(g) == graph_after)

    # The rewritten output must still carry a stack trace tag.
    self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_useless_dimshuffle_in_reshape():
    """Reshape(DimShuffle(x), shp) graphs lose the DimShuffle after the opt."""
    float64_type = lambda pattern: TensorType(broadcastable=pattern,
                                              dtype='float64')
    vector = float64_type((False,))('vector')
    mat = float64_type((False, False))('mat')
    row = float64_type((True, False))('row')
    col = float64_type((False, True))('col')

    # One Reshape-over-DimShuffle output per input kind.
    outputs = [
        tensor.reshape(vector.dimshuffle('x', 0), vector.shape),
        tensor.reshape(mat.dimshuffle('x', 0, 'x', 1), mat.shape),
        tensor.reshape(row.dimshuffle(1, 'x'), row.shape),
        tensor.reshape(col.dimshuffle(0), col.shape),
    ]
    g = FunctionGraph([vector, mat, row, col], outputs)

    graph_before = ("[Reshape{1}(DimShuffle{x,0}(vector), Shape(vector)), "
                    "Reshape{2}(DimShuffle{x,0,x,1}(mat), Shape(mat)), "
                    "Reshape{2}(DimShuffle{1,x}(row), Shape(row)), "
                    "Reshape{2}(DimShuffle{0}(col), Shape(col))]")
    assert_true(str(g) == graph_before)

    # Apply only the optimization under test, wrapped in an out2in pass.
    useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
    useless_dimshuffle_in_reshape.optimize(g)

    graph_after = ("[Reshape{1}(vector, Shape(vector)), "
                   "Reshape{2}(mat, Shape(mat)), "
                   "Reshape{2}(row, Shape(row)), "
                   "Reshape{2}(col, Shape(col))]")
    assert_true(str(g) == graph_after)

    # The rewritten output must still carry a stack trace tag.
    assert_true(hasattr(g.outputs[0].tag, 'trace'))
def test_add_canonizer_problem0():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论