提交 f71619bb authored 作者: Frederic Bastien's avatar Frederic Bastien

Added optimization to remove useless subtensor.

上级 fb182b59
......@@ -1088,6 +1088,21 @@ def local_upcast_elemwise_constant_inputs(node):
# Subtensor opts #
##################
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Subtensor])
def local_useless_subtensor(node):
"""
Remove Subtensor if it take the full input
"""
if (isinstance(node.op, T.Subtensor) and
all([isinstance(idx, slice) and
idx.start in [0, None] and
idx.stop in [sys.maxint, None]
and idx.step in [1, None] for idx in node.op.idx_list])):
x = node.inputs[0]
return [x]
@register_canonicalize
@gof.local_optimizer([])
def local_subtensor_lift(node):
......
......@@ -1145,6 +1145,15 @@ def test_log_add():
#TODO: (write and) test that the optimization works with Sum in addition to working with Add.
def test_local_useless_subtensor():
x = TT.matrix('x')
f = function([x], TT.exp(x)[0:], mode=mode_opt)
prog=f.maker.env.toposort()
assert prog[0].op == TT.exp
assert len(prog)==1
f([[0,1],[2,3]]) # let debugmode test something
class test_local_subtensor_lift(unittest.TestCase):
def test0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论