提交 1c26e1a9 authored 作者: Frederic's avatar Frederic

Add opt local_reshape_lift to enable softplus opt with reshape in the middle.

上级 1e2d5bca
......@@ -382,6 +382,17 @@ class T_softplus_opts(unittest.TestCase):
assert isinstance(topo[2].op.scalar_op, theano.scalar.Neg)
f(numpy.random.rand(54, 11).astype(config.floatX))
# Same test with a reshape
out = T.log(1 - sigmoid(x).reshape([x.size]))
f = theano.function([x], out, mode=self.m)
topo = f.maker.fgraph.toposort()
#assert len(topo) == 3
assert any(isinstance(node.op, T.Reshape) for node in topo)
assert any(isinstance(getattr(node.op, 'scalar_op', None),
theano.tensor.nnet.sigm.ScalarSoftplus)
for node in topo)
f(numpy.random.rand(54, 11).astype(config.floatX))
def test_log1pexp_to_softplus(self):
m = theano.config.mode
if m == 'FAST_COMPILE':
......
......@@ -2436,6 +2436,26 @@ def local_reshape_chain(node):
return False
register_canonicalize(local_reshape_chain)
@register_canonicalize
@register_stabilize
@gof.local_optimizer([])
def local_reshape_lift(node):
"""
Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
This optimization is needed by optimization
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape.
"""
if (isinstance(node.op, T.Reshape) and
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1):
r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
e = node.inputs[0].owner.op(r)
return [e]
if 0:
# TODO: Test that this optimziation works.
@register_canonicalize
......
......@@ -4004,6 +4004,19 @@ def test_local_flatten_lift():
assert isinstance(topo[1].op, tensor.Elemwise)
def test_local_reshape_lift():
x = tensor.tensor4()
out = T.exp(x).reshape([x.size])
assert out.ndim == 1
mode = compile.mode.get_default_mode()
mode = mode.including('local_reshape_lift')
f = theano.function([x], out, mode=mode)
f(numpy.random.rand(5, 4, 3, 2).astype(config.floatX))
topo = f.maker.fgraph.toposort()
assert isinstance(topo[-2].op, tensor.Reshape)
assert isinstance(topo[-1].op, tensor.Elemwise)
class Test_lift_transpose_through_dot(unittest.TestCase):
def simple_optimize(self, g):
out2in(opt.local_useless_elemwise).optimize(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论