提交 a074ec4a authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2515 from sisp/local_useless_subtensor

AdvancedSubtensor1 support for local_useless_subtensor optimization
......@@ -1917,16 +1917,21 @@ def local_set_to_inc_subtensor(node):
@register_canonicalize
@register_specialize
@gof.local_optimizer([Subtensor])
@gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_useless_subtensor(node):
"""
Remove Subtensor if it takes the full input
Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the
AdvancedSubtensor1 case, the full input is taken when the indices are
equivalent to `arange(0, input.shape[0], 1)` using either an explicit
list/vector or the ARange op.
"""
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'):
return
shape_of = node.fgraph.shape_feature.shape_of
if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'):
return
shape_of = node.fgraph.shape_feature.shape_of
cdata = node.op.get_constant_idx(node.inputs, allow_partial=True)
for pos, idx in enumerate(cdata):
if not isinstance(idx, slice):
......@@ -1985,8 +1990,43 @@ def local_useless_subtensor(node):
pass
else:
return False
elif isinstance(node.op, AdvancedSubtensor1):
# get length of the indexed tensor along the first axis
try:
length = get_scalar_constant_value(shape_of[node.inputs[0]][0])
except NotScalarConstantError:
return False
# get index (which must be a vector by definition)
idx = node.inputs[1]
# `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
# this optimization
if isinstance(idx, T.Constant):
idx = idx.value
if len(idx) != length:
return False
if numpy.any(idx != numpy.arange(length)):
return False
elif idx.owner is not None and isinstance(idx.owner.op, T.ARange):
try:
start, stop, step = map(get_scalar_constant_value,
idx.owner.inputs)
except NotScalarConstantError:
return False
if start != 0:
return False
if stop != length:
return False
if step != 1:
return False
else:
return False
else:
return False
return [node.inputs[0]]
return [node.inputs[0]]
@register_canonicalize
......
......@@ -45,6 +45,7 @@ from theano.tensor import vector, ivector, lvector, fvector, dvector
from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix
from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices
from theano.tensor import (
AdvancedSubtensor1,
as_tensor_variable,
inplace,
Join,
......@@ -1713,6 +1714,29 @@ def test_local_useless_subtensor():
f([[1, 2, 3], [4, 5, 6]], 1)
f([[1, 2, 3], [4, 5, 6]], 3)
# Test AdvancedSubtensor1 case when all rows are selected by a list/vector
# or ARange op
for dims, res in (([0, 1], True),
([1, 0], False),
([0, 0], False),
([0, 0, 1], False),
(T.arange(2), True),
(T.arange(0, 2), True),
(T.arange(0, 2, 2), False),
(T.arange(0, 2, -1), False),
(T.arange(1, 2), False)):
f = function([x], tensor.exp(x_c).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
prog = f.maker.fgraph.toposort()
if res:
assert isinstance(prog[0].op, theano.tensor.SpecifyShape), dims
assert prog[1].op == tensor.exp, dims
assert len(prog) == 2, dims
else:
assert any([isinstance(node.op, AdvancedSubtensor1)
for node in prog])
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
class test_local_subtensor_make_vector(unittest.TestCase):
def test_scalar_idx(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论