Commit 7eb745d8 authored by Frederic

Add an optimization that lifts flatten toward the inputs for unary elemwise ops.

This enables elemwise optimizations when a flatten sits in the middle of the graph. It can also enable more elemwise fusions that were previously split apart by a flatten.
Parent 3fde8039
...@@ -360,7 +360,7 @@ class T_softplus_opts(unittest.TestCase): ...@@ -360,7 +360,7 @@ class T_softplus_opts(unittest.TestCase):
f(numpy.random.rand(54).astype(config.floatX)) f(numpy.random.rand(54).astype(config.floatX))
def test_log1msigm_to_softplus(self): def test_log1msigm_to_softplus(self):
x = T.vector() x = T.matrix()
out = T.log(1 - sigmoid(x)) out = T.log(1 - sigmoid(x))
f = theano.function([x], out, mode=self.m) f = theano.function([x], out, mode=self.m)
...@@ -369,7 +369,18 @@ class T_softplus_opts(unittest.TestCase): ...@@ -369,7 +369,18 @@ class T_softplus_opts(unittest.TestCase):
assert isinstance(topo[0].op.scalar_op, assert isinstance(topo[0].op.scalar_op,
theano.tensor.nnet.sigm.ScalarSoftplus) theano.tensor.nnet.sigm.ScalarSoftplus)
assert isinstance(topo[1].op.scalar_op, theano.scalar.Neg) assert isinstance(topo[1].op.scalar_op, theano.scalar.Neg)
f(numpy.random.rand(54).astype(config.floatX)) f(numpy.random.rand(54, 11).astype(config.floatX))
# Same test with a flatten
out = T.log(1 - T.flatten(sigmoid(x)))
f = theano.function([x], out, mode=self.m)
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert isinstance(topo[0].op, T.Flatten)
assert isinstance(topo[1].op.scalar_op,
theano.tensor.nnet.sigm.ScalarSoftplus)
assert isinstance(topo[2].op.scalar_op, theano.scalar.Neg)
f(numpy.random.rand(54, 11).astype(config.floatX))
def test_log1pexp_to_softplus(self): def test_log1pexp_to_softplus(self):
m = theano.config.mode m = theano.config.mode
......
...@@ -2385,6 +2385,27 @@ def local_div_switch_sink(node): ...@@ -2385,6 +2385,27 @@ def local_div_switch_sink(node):
return False return False
################
# Flatten Opts #
################
@register_canonicalize
@register_stabilize
@gof.local_optimizer([])
def local_flatten_lift(node):
    """
    Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x))

    Lifting the flatten above a unary elemwise op is needed so that the
    optimization nnet/sigm.py:log1msigm_to_softplus can still apply when
    a flatten sits between the elemwise ops.
    """
    # Only rewrite a Flatten node.
    if not isinstance(node.op, T.Flatten):
        return
    # The flattened value must itself be produced by a unary Elemwise.
    producer = node.inputs[0].owner
    if producer is None:
        return
    if not isinstance(producer.op, T.Elemwise):
        return
    if len(producer.inputs) != 1:
        return
    # Swap the order: flatten the elemwise's input first, then re-apply
    # the same unary elemwise op on the flattened result.
    flattened = node.op(producer.inputs[0])
    return [producer.op(flattened)]
################## ##################
# Reshape opts # # Reshape opts #
################## ##################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论