Commit 973c2669 authored by Ian Goodfellow

marked log softmax optimization as being allowed to remove NaN

Parent ea55086c
...@@ -1735,7 +1735,11 @@ prepend_1_to_each_row = Prepend_scalar_constant_to_each_row(1.) ...@@ -1735,7 +1735,11 @@ prepend_1_to_each_row = Prepend_scalar_constant_to_each_row(1.)
# as X-X.max(axis=1).dimshuffle(0,'x') - log(exp(X-X.max(axis=1).dimshuffle(0,'x')).sum(axis=1)).dimshuffle(0,'x')
def make_out_pattern(X):
    """Build the numerically stable log-softmax graph for ``X``.

    Rewrites ``log(softmax(X))`` as
    ``(X - m) - log(exp(X - m).sum(axis=1))`` with ``m = X.max(axis=1)``,
    which avoids the overflow/underflow in ``exp`` that the naive form
    suffers from.  Used as the ``out_pattern`` builder for the
    ``local_log_softmax`` PatternSub below.

    NOTE(review): indexing uses ``axis=1`` with ``dimshuffle(0,'x')``, so
    this assumes ``X`` is a 2-D (batch, classes) variable — confirm at the
    optimization's registration site.
    """
    # Subtract the row-wise max so exp() never overflows.
    stabilized_X = X - X.max(axis=1).dimshuffle(0,'x')
    out_var = stabilized_X - tensor.log(tensor.exp(stabilized_X).sum(axis=1)).dimshuffle(0,'x')
    # Tell DEBUG_MODE that it's OK if the original graph produced NaN
    # and the optimized graph does not.
    out_var.values_eq_approx = out_var.type.values_eq_approx_remove_nan
    return out_var
local_log_softmax = gof.PatternSub( in_pattern = (tensor.log, (softmax, 'x')), local_log_softmax = gof.PatternSub( in_pattern = (tensor.log, (softmax, 'x')),
out_pattern = (make_out_pattern, 'x'), out_pattern = (make_out_pattern, 'x'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论