提交 a074ec4a authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2515 from sisp/local_useless_subtensor

AdvancedSubtensor1 support for local_useless_subtensor optimization
...@@ -1917,16 +1917,21 @@ def local_set_to_inc_subtensor(node): ...@@ -1917,16 +1917,21 @@ def local_set_to_inc_subtensor(node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_useless_subtensor(node): def local_useless_subtensor(node):
""" """
Remove Subtensor if it takes the full input Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the
AdvancedSubtensor1 case, the full input is taken when the indices are
equivalent to `arange(0, input.shape[0], 1)` using either an explicit
list/vector or the ARange op.
""" """
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'):
return
shape_of = node.fgraph.shape_feature.shape_of
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'):
return
shape_of = node.fgraph.shape_feature.shape_of
cdata = node.op.get_constant_idx(node.inputs, allow_partial=True) cdata = node.op.get_constant_idx(node.inputs, allow_partial=True)
for pos, idx in enumerate(cdata): for pos, idx in enumerate(cdata):
if not isinstance(idx, slice): if not isinstance(idx, slice):
...@@ -1985,8 +1990,43 @@ def local_useless_subtensor(node): ...@@ -1985,8 +1990,43 @@ def local_useless_subtensor(node):
pass pass
else: else:
return False return False
elif isinstance(node.op, AdvancedSubtensor1):
# get length of the indexed tensor along the first axis
try:
length = get_scalar_constant_value(shape_of[node.inputs[0]][0])
except NotScalarConstantError:
return False
# get index (which must be a vector by definition)
idx = node.inputs[1]
# `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
# this optimization
if isinstance(idx, T.Constant):
idx = idx.value
if len(idx) != length:
return False
if numpy.any(idx != numpy.arange(length)):
return False
elif idx.owner is not None and isinstance(idx.owner.op, T.ARange):
try:
start, stop, step = map(get_scalar_constant_value,
idx.owner.inputs)
except NotScalarConstantError:
return False
if start != 0:
return False
if stop != length:
return False
if step != 1:
return False
else:
return False
else:
return False
return [node.inputs[0]] return [node.inputs[0]]
@register_canonicalize @register_canonicalize
......
...@@ -45,6 +45,7 @@ from theano.tensor import vector, ivector, lvector, fvector, dvector ...@@ -45,6 +45,7 @@ from theano.tensor import vector, ivector, lvector, fvector, dvector
from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix
from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices
from theano.tensor import ( from theano.tensor import (
AdvancedSubtensor1,
as_tensor_variable, as_tensor_variable,
inplace, inplace,
Join, Join,
...@@ -1713,6 +1714,29 @@ def test_local_useless_subtensor(): ...@@ -1713,6 +1714,29 @@ def test_local_useless_subtensor():
f([[1, 2, 3], [4, 5, 6]], 1) f([[1, 2, 3], [4, 5, 6]], 1)
f([[1, 2, 3], [4, 5, 6]], 3) f([[1, 2, 3], [4, 5, 6]], 3)
# Test AdvancedSubtensor1 case when all rows are selected by a list/vector
# or ARange op
for dims, res in (([0, 1], True),
([1, 0], False),
([0, 0], False),
([0, 0, 1], False),
(T.arange(2), True),
(T.arange(0, 2), True),
(T.arange(0, 2, 2), False),
(T.arange(0, 2, -1), False),
(T.arange(1, 2), False)):
f = function([x], tensor.exp(x_c).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
prog = f.maker.fgraph.toposort()
if res:
assert isinstance(prog[0].op, theano.tensor.SpecifyShape), dims
assert prog[1].op == tensor.exp, dims
assert len(prog) == 2, dims
else:
assert any([isinstance(node.op, AdvancedSubtensor1)
for node in prog])
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
class test_local_subtensor_make_vector(unittest.TestCase): class test_local_subtensor_make_vector(unittest.TestCase):
def test_scalar_idx(self): def test_scalar_idx(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论