提交 39a5efab authored 作者: AdeB's avatar AdeB

log_sum_exp opt: no more equilibrium so that it is only applied once

上级 b8fffd75
...@@ -6085,8 +6085,6 @@ def local_log_add(node): ...@@ -6085,8 +6085,6 @@ def local_log_add(node):
return [ret] return [ret]
@register_stabilize
@register_specialize
@gof.local_optimizer([T.log]) @gof.local_optimizer([T.log])
def local_log_sum_exp(node): def local_log_sum_exp(node):
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max))) # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
...@@ -6103,22 +6101,12 @@ def local_log_sum_exp(node): ...@@ -6103,22 +6101,12 @@ def local_log_sum_exp(node):
return return
exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
if not exp_node and exp_node.op != T.exp: if not exp_node or not (
isinstance(exp_node.op, Elemwise) and
isinstance(exp_node.op.scalar_op, scalar.Exp)):
return return
pre_exp = exp_node.inputs[0] pre_exp = exp_node.inputs[0]
# optimisation may have already been applied
if (pre_exp.owner and
isinstance(pre_exp.owner.op, T.Elemwise) and
pre_exp.owner.op.scalar_op == scalar.sub):
max_node = pre_exp.owner.inputs[1].owner
if max_node and isinstance(max_node.op, T.DimShuffle):
max_node = max_node.inputs[0].owner
if not isinstance(max_node.op, T.MaxAndArgmax):
return
if max_node.inputs[0] == pre_exp.owner.inputs[0]:
return
max_pre_keepdims = T.max(pre_exp, axis=axis, keepdims=True) max_pre_keepdims = T.max(pre_exp, axis=axis, keepdims=True)
ret = (max_pre_keepdims + T.log(T.sum(T.exp(pre_exp - max_pre_keepdims), ret = (max_pre_keepdims + T.log(T.sum(T.exp(pre_exp - max_pre_keepdims),
...@@ -6131,6 +6119,11 @@ def local_log_sum_exp(node): ...@@ -6131,6 +6119,11 @@ def local_log_sum_exp(node):
return [ret] return [ret]
compile.optdb.register('local_log_sum_exp',
in2out(local_log_sum_exp, ignore_newtrees=True),
1.6, 'fast_run')
def add_calculate(num, denum, aslist=False, out_type=None): def add_calculate(num, denum, aslist=False, out_type=None):
# TODO: make sure that this function and mul_calculate are similar # TODO: make sure that this function and mul_calculate are similar
if out_type is None: if out_type is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论