提交 a51c2b88 authored 作者: Frederic Bastien's avatar Frederic Bastien

Lift subtensors taken on elemwise ops in more cases. Test it. Don't lift if the…

Lift subtensors taken on elemwise ops in more cases. Test it. Don't lift if the elemwise output is used by other computations.
上级 f8927951
......@@ -1035,17 +1035,45 @@ def local_upcast_elemwise_constant_inputs(node):
@register_canonicalize
@gof.local_optimizer([])
def local_subtensor_lift(node):
    """
    Move a Subtensor from the output of an Elemwise onto its inputs.

    unary(x)[idx] -> unary(x[idx])  # any broadcast pattern
    elemwise(x, ...)[idx] -> elemwise(x[idx], ...)
      when each of x, ... is either broadcastable along all of its
      dimensions or not broadcastable along any of them.

    Returns False when the indexed variable has no owner or has more than
    one client, a one-element list holding the replacement output when the
    rewrite applies, and None in every other case.
    """
    if not isinstance(node.op, T.Subtensor):
        return None

    u = node.inputs[0]
    # Lifting when the elemwise result is consumed elsewhere would keep the
    # full elemwise around and add a second one on the sliced inputs.
    if not u.owner or len(u.clients) > 1:
        return False

    if not isinstance(u.owner.op, T.Elemwise):
        return None

    idx = node.inputs[1:]
    elem_inputs = u.owner.inputs

    if len(elem_inputs) == 1:
        # Unary case: the single input is indexed directly, whatever its
        # broadcast pattern is.
        return [u.owner.op(node.op(elem_inputs[0], *idx))]

    # Only handle inputs that are broadcastable along no dimension or
    # along every dimension; mixed patterns are left alone.
    if not all(sum(i.type.broadcastable) in (i.ndim, 0)
               for i in elem_inputs):
        return None

    out_ndim = node.outputs[0].ndim
    new_inputs = []
    for i in elem_inputs:
        if sum(i.type.broadcastable) == 0:
            # Not broadcasted: index it exactly like the elemwise output.
            new_inputs.append(node.op(i, *idx))
        elif i.ndim == out_ndim:
            # The subtensor kept every dimension (e.g. a slice); the
            # fully-broadcastable input can be reused unchanged.
            new_inputs.append(i)
        else:
            # The subtensor removed some dimensions, so the number of
            # dimensions of this broadcasted input must be lowered to match.
            # One 'x' entry is passed per output dimension; the previous
            # string form 'x' * out_ndim was only a valid pattern when
            # out_ndim == 1.
            new_inputs.append(i.dimshuffle(*(['x'] * out_ndim)))
    return [u.owner.op(*new_inputs)]
@register_canonicalize
@gof.local_optimizer([None])
def local_IncSubtensor_serialize(node):
......
......@@ -1130,7 +1130,7 @@ def test_log_add():
#TODO: (write and) test that the optimization works with Sum in addition to working with Add.
class test_local_subtensor_unary(unittest.TestCase):
class test_local_subtensor_lift(unittest.TestCase):
def test0(self):
# basic test that the Op works
......@@ -1143,10 +1143,66 @@ class test_local_subtensor_unary(unittest.TestCase):
prog=f.maker.env.toposort()
assert isinstance(prog[0].op, TT.Subtensor) #first subtensor
assert prog[1].op == TT.exp
assert len(prog)==2
f([[0,1],[2,3]]) # let debugmode test something
def test0b(self):
    # Same graph as test0, except that the elemwise result is also
    # returned as an output, so the subtensor must NOT be lifted.
    x = TT.matrix('x')
    outputs = [TT.exp(x)[0], TT.exp(x)]
    fn = function([x], outputs, mode=mode_opt)

    topo = fn.maker.env.toposort()
    # exp is still applied to the full matrix, before the subtensor.
    assert topo[0].op == TT.exp
    assert isinstance(topo[1].op, TT.Subtensor)
    assert isinstance(topo[2].op,
                      theano.compile.function_module.DeepCopyOp)
    assert len(topo) == 3
    fn([[0, 1], [2, 3]])  # let debugmode test something
def test1(self):
    # The optimization must apply when one of the elemwise inputs is a
    # broadcasted scalar.
    x = TT.matrix('x')
    y = TT.scalar('y')
    z = TT.matrix('z')
    fn = function([x, y, z], TT.exp(x + y + z)[0], mode=mode_opt)

    topo = fn.maker.env.toposort()
    assert isinstance(topo[1].op, TT.DimShuffle)
    assert isinstance(topo[0].op, TT.Subtensor)  # first subtensor
    assert isinstance(topo[2].op, TT.Subtensor)  # second subtensor
    # The remaining elemwise work is expected to be a single Composite.
    assert isinstance(topo[3].op.scalar_op, theano.scalar.Composite)
    assert len(topo) == 4
    fn([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])  # let debugmode test something
def test2(self):
    # Same as test1, but the subtensor takes a slice instead of a
    # single row.
    x = TT.matrix('x')
    y = TT.scalar('y')
    z = TT.matrix('z')
    fn = function([x, y, z], TT.exp(x + y + z)[0:2], mode=mode_opt)

    topo = fn.maker.env.toposort()
    assert isinstance(topo[1].op, TT.DimShuffle)
    assert isinstance(topo[0].op, TT.Subtensor)  # first subtensor
    assert isinstance(topo[2].op, TT.Subtensor)  # second subtensor
    # The remaining elemwise work is expected to be a single Composite.
    assert isinstance(topo[3].op.scalar_op, theano.scalar.Composite)
    assert len(topo) == 4
    fn([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])  # let debugmode test something
def test3(self):
    # For a unary elemwise the optimization must apply even when the
    # input has a broadcasted dimension.
    y = TT.vector('y')
    column = y.dimshuffle(0, 'x')
    fn = function([y], TT.exp(column)[0], mode=mode_opt)

    topo = fn.maker.env.toposort()
    assert isinstance(topo[0].op, TT.DimShuffle)
    assert isinstance(topo[1].op, TT.Subtensor)
    # exp now comes last, after the subtensor.
    assert topo[2].op == TT.exp
    assert len(topo) == 3
    fn([4, 5])  # let debugmode test something
def test4(self):
# basic test that the optimization doesn't work with broadcasting
# ... It *could* be extended to,
# ... but right now it doesn't, so it shouldn't try.
......@@ -1163,7 +1219,22 @@ class test_local_subtensor_unary(unittest.TestCase):
assert prog[1].op == TT.add
assert isinstance(prog[2].op, TT.Subtensor) #first subtensor
assert prog[3].op == inplace.exp_inplace
assert len(prog)==4
f([[0,1],[2,3]], [4,5]) # let debugmode test something
def test5(self):
    # The subtensor must not be lifted when the output of the elemwise
    # is reused by another computation.
    x = TT.matrix('x')
    y = TT.vector('y')
    outputs = [TT.exp(x + y)[0], TT.exp(x + y) + x]
    fn = function([x, y], outputs, mode=mode_opt)

    topo = fn.maker.env.toposort()
    assert isinstance(topo[0].op, TT.DimShuffle)
    # The shared add/exp computation is expected to be one Composite.
    assert isinstance(topo[1].op.scalar_op, theano.scalar.Composite)
    assert topo[2].op == TT.add
    assert isinstance(topo[3].op, TT.Subtensor)  # first subtensor
    assert len(topo) == 4
    fn([[0, 1], [2, 3]], [4, 5])  # let debugmode test something
def test_local_fill_useless():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论