提交 028935f7 authored 作者: James Bergstra's avatar James Bergstra

Added log(1+x) -> log1p(x) optimization

上级 611b6c70
......@@ -44,23 +44,32 @@ def _fill_chain(new_out, orig_inputs):
new_out = T.fill(i, new_out)
return [new_out]
def get_constant_value(v):
def get_constant_value(v, fill=False):
"""return the constant value underlying variable `v`
If v is the output of dimshuffles, this function digs through them.
If v is the output of dimshuffles, fills, this function digs through them.
If `v` is not some view of constant data, then raise a TypeError.
if fill is True, then it returns (v, [...]) where the second term is a list of variables
that were used in the fill expressions
:note: There may be another function similar to this one in the code, but I'm not sure where it
is.
"""
if not isinstance(v, gof.Variable):
return v # why would this happen?
if isinstance(v, gof.Constant):
if fill:
return v.data, []
return v.data
if v.owner and isinstance(v.owner.op, T.DimShuffle):
return get_constant_value(v.owner.inputs[0])
return get_constant_value(v.owner.inputs[0], fill=fill)
if fill:
if v.owner and v.owner.op == T.fill:
shape, val = v.owner.inputs
# fill(a,b) fills the shape of 'a' filled with 'b'
rval, rshapes = get_constant_value(val, fill=fill)
return rval, rshapes + [shape]
raise TypeError(v)
@gof.optimizer
......@@ -1122,6 +1131,30 @@ register_specialize(local_add_specialize)
mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, local_fill_sink))
@register_specialize
@gof.local_optimizer([T.log])
def local_log1p(node):
# log(1+exp(x)) -> log1p(x)
if node.op == T.log:
log_arg, = node.inputs
if log_arg.owner and log_arg.owner.op == T.add:
add_inputs = log_arg.owner.inputs
consts = [0]
fills = []
nonconsts = []
for add_in in add_inputs:
try:
v, f = get_constant_value(add_in, fill=True)
consts.append(v)
fills.extend(f)
except:
nonconsts.append(add_in)
if nonconsts:
if numpy.allclose(numpy.sum(consts), 1):
if len(nonconsts)==1:
return _fill_chain(T.log1p(nonconsts[0]), fills)
else:
return _fill_chain(T.log1p(T.add(*nonconsts)), fills)
def add_calculate(num, denum, aslist = False, out_type=None):
......
......@@ -7,7 +7,7 @@ import theano
from theano import gof
from theano.tensor.opt import *
from theano import tensor
from theano.tensor import TensorType
from theano.tensor import TensorType, inplace
from theano.gof import Env
from theano.tensor.elemwise import DimShuffle
from theano import pprint, shared
......@@ -904,7 +904,38 @@ class test_fusion(unittest.TestCase):
# cases[id]=None #to remove g, that link to out that link to the ndarray!
#g.owner.inputs[0] is out... make owner a weakref?
def test_log1p():
# check some basic cases
x = dvector()
f = function([x], T.log(1+(x)), mode='FAST_RUN')
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
f = (function([x], T.log(1+(-x))))
assert [node.op for node in f.maker.env.toposort()] == [T.neg, inplace.log1p_inplace]
f = (function([x], -T.log(1+(-x))))
assert [node.op for node in f.maker.env.toposort()] == [T.neg, inplace.log1p_inplace, inplace.neg_inplace]
# check trickier cases (and use different dtype)
y = fmatrix()
f = (function([x,y], T.log(fill(y,1)+(x))))
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f = (function([x,y], T.log(0+(x) + fill(y,1.0) )))
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f = (function([x,y], T.log(2+(x) - fill(y,1.0) )))
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f([1e-7, 10], [[0, 0], [0, 0]]) #debugmode will verify values
# should work for complex
z = zmatrix()
f = function([z], T.log(1+(z)))
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
# should work for int
z = imatrix()
f = function([z], T.log(1+(z)))
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
if __name__ == '__main__':
# unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论