提交 e1528269 authored 作者: AdeB's avatar AdeB

Log sum exp optimization for numerical stability

上级 da95bf92
...@@ -6085,6 +6085,52 @@ def local_log_add(node): ...@@ -6085,6 +6085,52 @@ def local_log_add(node):
return [ret] return [ret]
@register_stabilize
@register_specialize
@gof.local_optimizer([T.log])
def local_log_sum_exp(node):
    """Rewrite ``log(sum(exp(x)))`` into a numerically stable form.

    Applies the identity::

        log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))

    Parameters
    ----------
    node
        The apply node under consideration. Only nodes whose op is
        ``T.log`` and whose input is a ``T.Sum`` (optionally wrapped in a
        ``T.DimShuffle``) of ``T.exp`` are rewritten.

    Returns
    -------
    list or None
        A one-element list holding the replacement variable, or ``None``
        when the pattern does not match or the rewrite is skipped.
    """
    if node.op != T.log:
        return

    sum_node = node.inputs[0].owner
    # If the sum has keepdims=True, there might be a dimshuffle
    if sum_node and isinstance(sum_node.op, T.DimShuffle):
        sum_node = sum_node.inputs[0].owner

    if not sum_node or not isinstance(sum_node.op, T.Sum):
        return

    exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
    # Must be `or`: bail out when the summed variable has no owner OR its
    # owner is not exp. With `and`, a missing owner raised AttributeError
    # and any non-exp op was rewritten as though it were exp.
    if not exp_node or exp_node.op != T.exp:
        return

    pre_exp = exp_node.inputs[0]
    # The optimisation may have already been applied: in that case the
    # argument of exp is `x - max(x)`.
    if (pre_exp.owner and
            isinstance(pre_exp.owner.op, T.Elemwise) and
            pre_exp.owner.op.scalar_op == scalar.sub):
        max_node = pre_exp.owner.inputs[1].owner
        if max_node and isinstance(max_node.op, T.DimShuffle):
            max_node = max_node.inputs[0].owner
        # The subtrahend may have no owner (graph input / constant), so
        # guard against None before inspecting `.op`.
        # NOTE(review): any subtraction whose second operand is not the
        # output of MaxAndArgmax is left untouched here, as in the
        # original code -- confirm this conservative skip is intended.
        if not max_node or not isinstance(max_node.op, T.MaxAndArgmax):
            return
        # Subtracting the max of the very same variable: already stable.
        if max_node.inputs[0] == pre_exp.owner.inputs[0]:
            return

    max_pre_keepdims = T.max(pre_exp, axis=axis, keepdims=True)

    ret = (max_pre_keepdims +
           T.log(T.sum(T.exp(pre_exp - max_pre_keepdims),
                       axis=axis, keepdims=True)))

    # Restore shape and broadcastable pattern of the variable that the
    # original log was applied to (the sum may or may not have kept dims).
    ret = T.reshape(ret, node.inputs[0].shape)
    ret = T.patternbroadcast(ret, node.inputs[0].broadcastable)

    return [ret]
def add_calculate(num, denum, aslist=False, out_type=None): def add_calculate(num, denum, aslist=False, out_type=None):
# TODO: make sure that this function and mul_calculate are similar # TODO: make sure that this function and mul_calculate are similar
if out_type is None: if out_type is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论