提交 1ede2527 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed local_useless_subtensor optimization

上级 371c37b8
......@@ -1245,32 +1245,59 @@ def local_useless_subtensor(node):
shape_of = node.env.shape_feature.shape_of
node_input_idx = 1
for pos, idx in enumerate(node.op.idx_list):
if not isinstance(idx, slice):
# If idx is not a slice, this means we remove this dimension
# from the output, so the subtensor is not useless
return False
if idx.start not in [0,None]:
# If the start of the slice is different from 0, or is a
# variable, then we assume the subtensor is not useless
return False
if idx.step not in [1, None]:
# If we are going backwards, or skipping elements, then this
# is not a useless subtensor
return False
length_pos_data = sys.maxint
length_pos_shape_i = None
try:
length_pos = shape_of[node.inputs[0]][pos]
if isinstance(length_pos, theano.tensor.basic.TensorConstant):
length_pos_data = length_pos.data
else:
length_pos_shape_i = node.inputs[node_input_idx].owner.inputs[0]
try:
length_pos_data = get_constant_value(length_pos)
except TypeError:
pass
if isinstance(idx.stop, theano.scalar.Scalar):
if isinstance(node.inputs[node_input_idx].owner.op,
T.ScalarFromTensor):
length_pos_shape_i = node.inputs[node_input_idx].owner.inputs[0]
else:
length_pos_shape_i = node.inputs[node_input_idx]
assert length_pos_shape_i.type == idx.stop
# We already know that start and step are not variables
# and so they don't appear in the input of the node
node_input_idx += 1
# Catch exception from shape_of
except Exception, e:
length_pos = None
if ( isinstance(idx,slice) and
idx.start in [0,None] and
idx.step in [1,None] and
(idx.stop in [sys.maxint, None, length_pos_data] or
(isinstance(idx.stop, int) and idx.stop>=length_pos_data) or
(isinstance(idx.stop, theano.scalar.Scalar) and
length_pos==length_pos_shape_i)
)):
if isinstance(idx.stop, int):
if idx.stop < length_pos_data:
return False
elif isinstance(idx.stop, theano.scalar.Scalar):
if length_pos_shape_i is None:
return False
if length_pos is None:
return False
if length_pos_shape_i != length_pos:
return False
elif idx.stop is None:
pass
else:
return False
if isinstance(idx, slice):
node_input_idx += sum([isinstance(idx.start, theano.scalar.Scalar),
isinstance(idx.stop, theano.scalar.Scalar),
isinstance(idx.step, theano.scalar.Scalar)])
return [node.inputs[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论