提交 028935f7 authored 作者: James Bergstra's avatar James Bergstra

Added log(1+x) -> log1p(x) optimization

上级 611b6c70
...@@ -44,23 +44,32 @@ def _fill_chain(new_out, orig_inputs): ...@@ -44,23 +44,32 @@ def _fill_chain(new_out, orig_inputs):
new_out = T.fill(i, new_out) new_out = T.fill(i, new_out)
return [new_out] return [new_out]
def get_constant_value(v): def get_constant_value(v, fill=False):
"""return the constant value underlying variable `v` """return the constant value underlying variable `v`
If v is the output of dimshuffles, this function digs through them. If v is the output of dimshuffles, fills, this function digs through them.
If `v` is not some view of constant data, then raise a TypeError. If `v` is not some view of constant data, then raise a TypeError.
if fill is True, then it returns (v, [...]) where the second term is a list of variables
that were used in the fill expressions
:note: There may be another function similar to this one in the code, but I'm not sure where it :note: There may be another function similar to this one in the code, but I'm not sure where it
is. is.
""" """
if not isinstance(v, gof.Variable):
return v # why would this happen?
if isinstance(v, gof.Constant): if isinstance(v, gof.Constant):
if fill:
return v.data, []
return v.data return v.data
if v.owner and isinstance(v.owner.op, T.DimShuffle): if v.owner and isinstance(v.owner.op, T.DimShuffle):
return get_constant_value(v.owner.inputs[0]) return get_constant_value(v.owner.inputs[0], fill=fill)
if fill:
if v.owner and v.owner.op == T.fill:
shape, val = v.owner.inputs
# fill(a,b) fills the shape of 'a' filled with 'b'
rval, rshapes = get_constant_value(val, fill=fill)
return rval, rshapes + [shape]
raise TypeError(v) raise TypeError(v)
@gof.optimizer @gof.optimizer
...@@ -1122,6 +1131,30 @@ register_specialize(local_add_specialize) ...@@ -1122,6 +1131,30 @@ register_specialize(local_add_specialize)
mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, local_fill_sink)) mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, local_fill_sink))
@register_specialize
@gof.local_optimizer([T.log])
def local_log1p(node):
# log(1+exp(x)) -> log1p(x)
if node.op == T.log:
log_arg, = node.inputs
if log_arg.owner and log_arg.owner.op == T.add:
add_inputs = log_arg.owner.inputs
consts = [0]
fills = []
nonconsts = []
for add_in in add_inputs:
try:
v, f = get_constant_value(add_in, fill=True)
consts.append(v)
fills.extend(f)
except:
nonconsts.append(add_in)
if nonconsts:
if numpy.allclose(numpy.sum(consts), 1):
if len(nonconsts)==1:
return _fill_chain(T.log1p(nonconsts[0]), fills)
else:
return _fill_chain(T.log1p(T.add(*nonconsts)), fills)
def add_calculate(num, denum, aslist = False, out_type=None): def add_calculate(num, denum, aslist = False, out_type=None):
......
...@@ -7,7 +7,7 @@ import theano ...@@ -7,7 +7,7 @@ import theano
from theano import gof from theano import gof
from theano.tensor.opt import * from theano.tensor.opt import *
from theano import tensor from theano import tensor
from theano.tensor import TensorType from theano.tensor import TensorType, inplace
from theano.gof import Env from theano.gof import Env
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano import pprint, shared from theano import pprint, shared
...@@ -904,7 +904,38 @@ class test_fusion(unittest.TestCase): ...@@ -904,7 +904,38 @@ class test_fusion(unittest.TestCase):
# cases[id]=None #to remove g, that link to out that link to the ndarray! # cases[id]=None #to remove g, that link to out that link to the ndarray!
#g.owner.inputs[0] is out... make owner a weakref? #g.owner.inputs[0] is out... make owner a weakref?
def test_log1p():
# check some basic cases
x = dvector()
f = function([x], T.log(1+(x)), mode='FAST_RUN')
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
f = (function([x], T.log(1+(-x))))
assert [node.op for node in f.maker.env.toposort()] == [T.neg, inplace.log1p_inplace]
f = (function([x], -T.log(1+(-x))))
assert [node.op for node in f.maker.env.toposort()] == [T.neg, inplace.log1p_inplace, inplace.neg_inplace]
# check trickier cases (and use different dtype)
y = fmatrix()
f = (function([x,y], T.log(fill(y,1)+(x))))
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f = (function([x,y], T.log(0+(x) + fill(y,1.0) )))
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f = (function([x,y], T.log(2+(x) - fill(y,1.0) )))
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f([1e-7, 10], [[0, 0], [0, 0]]) #debugmode will verify values
# should work for complex
z = zmatrix()
f = function([z], T.log(1+(z)))
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
# should work for int
z = imatrix()
f = function([z], T.log(1+(z)))
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() # unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论