提交 28e5b5cd authored 作者: James Bergstra's avatar James Bergstra

added log_add optimization

上级 9677a10f
......@@ -1843,6 +1843,27 @@ def local_log1p(node):
else:
return _fill_chain(T.log1p(T.add(*nonconsts)), scalar_inputs)
#TODO: in canonicalize, change log10 and log2 -> log
@register_stabilize
@gof.local_optimizer([T.log])
def local_log_add(node):
# log(exp(x)+exp(y))
#
# Suppose x >= y
# log(exp(x) + exp(y))
# log(exp(x) * (1 + exp(y)/exp(x)))
# x + log(1 + exp(y)/exp(x))
# x + log1p(exp(y)/exp(x))
# x + log1p(exp(y-x))
if node.op == T.log:
z = node.inputs[0]
if z.owner and z.owner.op == T.add:
zi = z.owner.inputs
pre_exp = [ x.owner.inputs[0] for x in zi if x.owner and x.owner.op == T.exp]
if len(pre_exp) == len(zi):
# all arguments to add are exp(<something>)
max_pre = T.maximum(*pre_exp)
return [max_pre + T.log1p(T.exp(T.add(*[p - max_pre for p in pre_exp])))]
def add_calculate(num, denum, aslist = False, out_type=None):
#TODO: make sure that this function and mul_calculate are similar
......
......@@ -930,6 +930,22 @@ def test_log1p():
theano.printing.debugprint(f)
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
def test_log_add():
m = theano.config.mode
if m == 'FAST_COMPILE':
m = 'FAST_RUN'
m = compile.mode.get_mode(m)
m = m.excluding('fusion')
# check some basic cases
x = dvector()
y = dvector()
f = function([x,y], T.log(T.exp(x) + T.exp(y)), mode=m)
theano.printing.debugprint( f)
print f([10000], [10000]) # causes overflow if handled incorrectly
assert numpy.allclose(f([10000], [10000]), 10000+numpy.log1p(1))
class test_local_subtensor_unary(unittest.TestCase):
def test0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论