提交 77634eed 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #4843 from lamblin/useless_reshape

Remove useless reshape
......@@ -21,6 +21,7 @@ from ..basic_ops import (
host_from_gpu, HostFromGpu, GpuFromHost, GpuReshape, GpuToGpu,
GpuAlloc, GpuAllocEmpty, GpuContiguous,
gpu_join, GpuJoin, GpuSplit, GpuEye, gpu_contiguous)
from ..elemwise import GpuDimShuffle, GpuElemwise
from ..subtensor import GpuSubtensor
from .config import mode_with_gpu, mode_without_gpu, test_ctx_name
......@@ -324,7 +325,7 @@ class G_reshape(test_basic.T_reshape):
mode=mode_with_gpu,
ignore_topo=(HostFromGpu, GpuFromHost,
theano.compile.DeepCopyOp,
theano.gpuarray.elemwise.GpuElemwise,
GpuDimShuffle, GpuElemwise,
theano.tensor.opt.Shape_i,
theano.tensor.opt.MakeVector))
assert self.op == GpuReshape
......
......@@ -39,7 +39,7 @@ from theano import scalar
from theano.scalar import basic
from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file
from theano.compile.ops import Shape_i
from theano.compile.ops import Shape, Shape_i
from theano.tensor.type import (values_eq_approx_remove_inf,
values_eq_approx_remove_nan,
values_eq_approx_remove_inf_nan)
......@@ -48,7 +48,8 @@ from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer)
from theano.gof import toolbox
from theano.tensor.basic import (Alloc, get_scalar_constant_value, ShapeError,
extract_constant, NotScalarConstantError)
extract_constant, NotScalarConstantError,
Reshape)
from six import StringIO
_logger = logging.getLogger('theano.tensor.opt')
......@@ -574,25 +575,6 @@ def local_dimshuffle_lift(node):
"""
op = node.op
if (isinstance(op, T.Reshape) and
node.inputs[0].owner is not None and
isinstance(node.inputs[0].owner.op, DimShuffle)):
new_order = node.inputs[0].owner.op.new_order
new_order = [i for i in new_order if i != 'x']
input = node.inputs[0].owner.inputs[0]
broadcastables = input.broadcastable
new_order_of_nonbroadcastables = []
for i, bd in zip(new_order, broadcastables):
if not bd:
new_order_of_nonbroadcastables.append(i)
no_change_in_order = all(
new_order_of_nonbroadcastables[i] <= new_order_of_nonbroadcastables[i + 1]
for i in xrange(len(new_order_of_nonbroadcastables) - 1))
if no_change_in_order:
shape = node.inputs[1]
ret = op.__class__(node.outputs[0].ndim)(input, shape)
copy_stack_trace(node.outputs[0], ret)
return [ret]
if not isinstance(op, DimShuffle):
return False
......@@ -626,6 +608,42 @@ def local_dimshuffle_lift(node):
return [ret]
@register_canonicalize
@gof.local_optimizer([Reshape])
def local_useless_dimshuffle_in_reshape(node):
    """
    Drop a DimShuffle that only adds/removes broadcastable dimensions
    when it feeds directly into a Reshape, since the Reshape subsumes it:

        reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
        reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
        reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
        reshape(col.dimshuffle(0), shp) => reshape(col, shp)
    """
    if not isinstance(node.op, Reshape):
        return False
    shuffled = node.inputs[0]
    if shuffled.owner is None or not isinstance(shuffled.owner.op, DimShuffle):
        return False

    inner_input = shuffled.owner.inputs[0]
    # Relative order of the non-broadcastable axes after the DimShuffle.
    # Broadcastable axes can be ignored: Reshape handles them either way.
    nonbroadcast_order = [
        axis
        for axis, bcast in zip(shuffled.owner.op.new_order,
                               shuffled.broadcastable)
        if not bcast]

    # The DimShuffle is removable only if it does not permute the
    # non-broadcastable axes (i.e. their order is non-decreasing).
    order_preserved = all(
        earlier <= later
        for earlier, later in zip(nonbroadcast_order, nonbroadcast_order[1:]))
    if order_preserved:
        target_shape = node.inputs[1]
        replacement = node.op.__class__(node.outputs[0].ndim)(inner_input,
                                                              target_shape)
        copy_stack_trace(node.outputs[0], replacement)
        return [replacement]
@register_canonicalize
@gof.local_optimizer([T.DimShuffle])
def local_lift_transpose_through_dot(node):
......@@ -4165,15 +4183,135 @@ register_canonicalize(local_reshape_chain(T.Reshape),
@gof.local_optimizer([T.Reshape])
def local_useless_reshape(node):
    """
    Remove two kinds of useless Reshape.

    - Remove Reshape when both the input and output have a single
      dimension (and the same broadcastable pattern).
    - Remove Reshape when reshaping to the shape of the input, either
      as ``Reshape(x, x.shape)`` or as a ``MakeVector`` whose entries
      each match the corresponding dimension of the input's shape.

    Parameters
    ----------
    node : Apply
        Candidate node; only acted upon when its op is a Reshape.

    Returns
    -------
    list or bool
        ``[input]`` when the Reshape is useless, a falsy value otherwise.
    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    input = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    # A useless reshape cannot change the number of dimensions.
    if input.ndim != output.ndim:
        return False

    # Simple case: both input and output have a single dimension.
    # This could hide errors if the user provides inconsistent shapes.
    if (input.ndim == 1 and output.ndim == 1 and
            input.broadcastable == output.broadcastable):
        return [input]

    # Second case: all the shapes match the input shape.
    # Match Reshape(x, x.shape)
    if output_shape.owner and isinstance(output_shape.owner.op, Shape):
        shape_input = output_shape.owner.inputs[0]
        if shape_input == input:
            return [input]

    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
        output_shape_is = output_shape.owner.inputs

        # The shape feature (if present) lets us compare against the
        # inferred shape of the input as well.
        if not hasattr(node, 'fgraph'):
            shape_feature = None
        else:
            shape_feature = getattr(node.fgraph, 'shape_feature', None)

        shape_match = [False] * input.ndim
        for dim in xrange(input.ndim):
            outshp_i = output_shape_is[dim]
            # Match Shape_i{dim}(input)
            if (outshp_i.owner and isinstance(outshp_i.owner.op, Shape_i) and
                    outshp_i.owner.op.i == dim and
                    outshp_i.owner.inputs[0] == input):
                shape_match[dim] = True
                continue

            # Match Shape(input)[dim]
            if (outshp_i.owner and isinstance(outshp_i.owner.op, Subtensor) and
                    len(outshp_i.owner.inputs) == 2 and
                    extract_constant(outshp_i.owner.inputs[1]) == dim):
                subtensor_inp = outshp_i.owner.inputs[0]
                if (subtensor_inp.owner and
                        isinstance(subtensor_inp.owner.op, Shape)):
                    shape_input_i = subtensor_inp.owner.inputs[0]
                    if shape_input_i == input:
                        shape_match[dim] = True
                        continue

            # Match 1 if input.broadcastable[dim] is True: a broadcastable
            # dimension necessarily has length 1.
            if (input.broadcastable[dim] and
                    extract_constant(outshp_i, only_process_constants=1) == 1):
                shape_match[dim] = True
                continue

            # Match shape_of[input][dim] or its constant equivalent
            if shape_feature:
                inpshp_i = shape_feature.get_shape(input, dim)
                if (inpshp_i == outshp_i or
                    (extract_constant(inpshp_i, only_process_constants=1) ==
                     extract_constant(outshp_i, only_process_constants=1))):
                    shape_match[dim] = True
                    continue

        if all(shape_match):
            return [input]

        # TODO later: if all the shapes except one match, we may want to
        # consider it useless as well, like we do in the 1-dim case.
@register_canonicalize
@gof.local_optimizer([T.Reshape])
def local_reshape_to_dimshuffle(node):
    """
    Replace broadcastable dimensions in a Reshape with a DimShuffle.

    The goal is to avoid using reshape to add or remove broadcastable
    dimensions, and to use dimshuffle instead, so dimshuffles can cancel
    out or be removed later on.

    For example:
        - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,))
        - reshape(x, (1, m, 1, n, 1, 1))
          --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n)))
    """
    if not isinstance(node.op, Reshape):
        return False

    reshaped = node.inputs[0]
    old_output = node.outputs[0]
    shape_vector = node.inputs[1]

    shuffle_pattern = []
    kept_dims = []
    for axis in xrange(old_output.ndim):
        # shape_vector is symbolic; extract_constant with
        # only_process_constants=False walks through how it was built to
        # decide whether this entry is the literal 1.
        length = extract_constant(shape_vector[axis],
                                  only_process_constants=False,
                                  elemwise=False)
        if length == 1:
            # Length-1 dimension: let a DimShuffle insert it instead.
            shuffle_pattern.append('x')
        else:
            # len(kept_dims) is the position of this dimension in the
            # output of the smaller, inner reshape.
            shuffle_pattern.append(len(kept_dims))
            kept_dims.append(length)

    # Only rewrite when at least one dimension was recognized as 1.
    if len(kept_dims) != old_output.ndim:
        inner = node.op.__class__(len(kept_dims))(reshaped, kept_dims)
        copy_stack_trace(old_output, inner)
        replacement = [DimShuffle(inner.type.broadcastable,
                                  shuffle_pattern)(inner)]
        copy_stack_trace(old_output, replacement)
        return replacement
@register_canonicalize
......
......@@ -5132,14 +5132,18 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin):
self.ignore_topo = ignore_topo
super(T_reshape, self).__init__(name)
def function(self, inputs, outputs, ignore_empty=False):
    """
    Compile a theano function and check its optimized graph.

    After filtering out the op classes listed in ``self.ignore_topo``,
    the optimized graph must contain exactly one node, and its op type
    must be ``self.op``.

    Parameters
    ----------
    inputs, outputs
        Passed directly to ``theano.function``.
    ignore_empty : bool
        When True, also accept a graph with zero remaining nodes (the
        op under test may have been optimized away entirely).

    Returns
    -------
    The compiled theano function.
    """
    f = function(inputs, outputs, mode=self.mode)
    # In FAST_COMPILE mode (with no explicit mode), the optimizations
    # under test do not run, so there is nothing to check in the graph.
    if self.mode is not None or theano.config.mode != "FAST_COMPILE":
        topo = f.maker.fgraph.toposort()
        topo_ = [node for node in topo
                 if not isinstance(node.op, self.ignore_topo)]
        if ignore_empty:
            assert len(topo_) <= 1, topo_
        else:
            assert len(topo_) == 1, topo_
        if len(topo_) > 0:
            assert type(topo_[0].op) is self.op
    return f
def test_reshape(self):
......@@ -5212,10 +5216,21 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin):
# test broadcast flag for constant value of 1
c = reshape(b, (b.shape[0], b.shape[1], 1))
f = self.function([b], c)
# That reshape may get replaced with a dimshuffle, with is ignored,
# so we pass "ignore_empty=True"
f = self.function([b], c, ignore_empty=True)
assert numpy.all(f(numpy.asarray([[0, 1, 2], [3, 4, 5]])) ==
numpy.asarray([[[0], [1], [2]], [[3], [4], [5]]]))
assert (f.maker.fgraph.toposort()[-2].outputs[0].type.broadcastable ==
assert (f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable ==
(False, False, True))
# test broadcast flag for constant value of 1 if it cannot be
# replaced with dimshuffle
c = reshape(b, (b.shape[1], b.shape[0], 1))
f = self.function([b], c, ignore_empty=True)
assert numpy.all(f(numpy.asarray([[0, 1, 2], [3, 4, 5]])) ==
numpy.asarray([[[0], [1]], [[2], [3]], [[4], [5]]]))
assert (f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable ==
(False, False, True))
def test_m1(self):
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论