提交 d50c460b authored 作者: Frederic Bastien's avatar Frederic Bastien

'Merge two Subtensor ops in this case: var[int:][-1]'

上级 e146275a
...@@ -1080,6 +1080,35 @@ def local_subtensor_lift(node):
new_inputs.append(i.dimshuffle('x'*node.outputs[0].ndim))
return [u.owner.op(*new_inputs)]
@register_canonicalize
@register_specialize
@gof.local_optimizer([])
def local_subtensor_merge(node):
    """Collapse ``var[int:][-1]`` into the single subtensor ``var[-1]``.

    The start index of the inner slice may be a Python int or a symbolic
    scalar.  The rewrite only fires when the intermediate subtensor
    result is consumed by this node alone.
    """
    # The outer op must be exactly the single index ``[-1]``.
    if not (isinstance(node.op, T.Subtensor) and
            len(node.inputs) == 1 and
            len(node.op.idx_list) == 1 and
            node.op.idx_list[0] == -1):
        return
    inner = node.inputs[0]
    # Bail out if the inner result has no producer or other consumers.
    if not inner.owner or len(inner.clients) > 1:
        return False
    inner_apply = inner.owner
    inner_op = inner_apply.op
    # The inner op must be ``var[start:]`` with an int (constant or
    # symbolic) start and no stop/step.
    if not (isinstance(inner_op, T.Subtensor) and
            len(inner_apply.inputs) in (1, 2) and
            len(inner_op.idx_list) == 1):
        return
    sl = inner_op.idx_list[0]
    if (isinstance(sl, slice) and
            isinstance(sl.start, (int, scalar.basic.Scalar)) and
            sl.stop is None and
            sl.step is None):
        # Index the original variable directly: var[-1].
        return [inner_apply.inputs[0][-1]]
@register_canonicalize
@gof.local_optimizer([None])
def local_IncSubtensor_serialize(node):
......
...@@ -1243,6 +1243,48 @@ class test_local_subtensor_lift(unittest.TestCase):
assert len(prog)==4
f([[0,1],[2,3]], [4,5]) # let debugmode test something
class test_local_subtensor_merge(unittest.TestCase):
    def test_const(self):
        # var[const::][-1] must collapse into the single subtensor var[-1]
        x = TT.matrix('x')
        for start in (-2, -1, 1):
            f = function([x], x[start::][-1], mode=mode_opt)
            #theano.printing.debugprint(f)
            nodes = f.maker.env.toposort()
            assert len(nodes) == 2
            assert isinstance(nodes[0].op, TT.Subtensor)
            assert isinstance(nodes[1].op,
                              theano.compile.function_module.DeepCopyOp)
            f([[0, 1], [2, 3]])  # let debugmode test something

    def test_scalar(self):
        # var[int_scalar::][-1] must also collapse into var[-1]
        x = TT.matrix('x')
        y = TT.iscalar('y')
        f = function([x, y], x[y::][-1], mode=mode_opt)
        #theano.printing.debugprint(f)
        nodes = f.maker.env.toposort()
        assert len(nodes) == 2
        assert isinstance(nodes[0].op, TT.Subtensor)
        assert isinstance(nodes[1].op,
                          theano.compile.function_module.DeepCopyOp)
        for start in range(-10, 2):
            f([[0, 1], [2, 3]], start)  # let debugmode test something

    def test_dont_opt(self):
        # var[int::][0] is NOT covered by the optimization: both
        # Subtensor nodes must survive in the compiled graph.
        x = TT.matrix('x')
        f = function([x], x[1::][0], mode=mode_opt)
        #theano.printing.debugprint(f)
        nodes = f.maker.env.toposort()
        assert len(nodes) == 3
        assert isinstance(nodes[0].op, TT.Subtensor)
        assert isinstance(nodes[1].op, TT.Subtensor)
        assert isinstance(nodes[2].op,
                          theano.compile.function_module.DeepCopyOp)
        f([[0, 1], [2, 3]])  # let debugmode test something
def test_local_fill_useless():
m = theano.config.mode
if m == 'FAST_COMPILE':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论