Commit 65898f85 authored by fvisin, committed by Francesco Visin

Remove unused optimization for log softmax

Parent f05a0c89
@@ -30,7 +30,6 @@ from theano.tensor.nnet.sigm import sigmoid, softplus
from theano.gradient import DisconnectedType
from theano.gradient import grad_not_implemented
from theano.tensor.nnet.blocksparse import sparse_block_dot
from theano.tensor.type import values_eq_approx_remove_nan
############
@@ -609,7 +608,6 @@ class LogSoftmax(gof.Op):
    activation function gets applied row-wise.
    """

    def make_node(self, x):
        x = tensor.as_tensor_variable(x)
        if x.type.ndim not in (1, 2) \
@@ -646,6 +644,25 @@
logsoftmax_op = LogSoftmax()
@opt.register_specialize('stabilize')
@gof.local_optimizer([tensor.Elemwise])
def local_logsoftmax(node):
"""
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
Note: only forward pass is affected
"""
    if (isinstance(node.op, tensor.Elemwise) and
            isinstance(node.op.scalar_op, scalar.basic.Log) and
            len(node.inputs) == 1 and
            node.inputs[0].owner and
            isinstance(node.inputs[0].owner.op, Softmax)):
        # open question: should this also require
        # len(node.inputs[0].owner.out.clients) == 1?
        inVars = node.inputs[0].owner.inputs[0]
        new_op = LogSoftmax()
        return [new_op(inVars)]
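# A minimal sketch (assuming Theano's standard compile/graph API; the
# helper name is illustrative, not part of the commit) of the rewrite
# above in action: log(softmax(x)) should now compile down to a single
# LogSoftmax node.
def _local_logsoftmax_example():
    import theano
    import theano.tensor as T
    x = T.matrix('x')
    f = theano.function([x], T.log(T.nnet.softmax(x)))
    # look for the fused op in the optimized graph
    return any(isinstance(n.op, LogSoftmax)
               for n in f.maker.fgraph.toposort())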
def softmax_graph(c):
    return tensor.exp(c) / tensor.exp(c).sum(axis=-1, keepdims=True)
@@ -2033,27 +2050,6 @@ prepend_0_to_each_row = Prepend_scalar_constant_to_each_row(0.)
prepend_1_to_each_row = Prepend_scalar_constant_to_each_row(1.)
# numerically stabilize log softmax (X) as
#   X - X.max(axis=1).dimshuffle(0, 'x')
#   - log(exp(X - X.max(axis=1).dimshuffle(0, 'x')).sum(axis=1)).dimshuffle(0, 'x')
def make_out_pattern(X):
    stabilized_X = X - X.max(axis=1).dimshuffle(0, 'x')
    out_var = stabilized_X - tensor.log(tensor.exp(stabilized_X).sum(
        axis=1)).dimshuffle(0, 'x')
    # tell DEBUG_MODE that it's OK if the original graph produced NaN
    # and the optimized graph does not
    out_var.values_eq_approx = values_eq_approx_remove_nan
    return out_var
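# A NumPy sketch (an assumed analogue, not part of this module; the helper
# name is illustrative) of the log-sum-exp identity implemented above: the
# naive log(softmax(X)) overflows exp() for large logits, while the
# stabilized form stays finite.
def _stable_log_softmax_example():
    import numpy as np
    X = np.array([[1000.0, 0.0, -1000.0]])
    # naive form: exp(1000.) overflows, giving nan / -inf entries
    naive = np.log(np.exp(X) / np.exp(X).sum(axis=1, keepdims=True))
    m = X.max(axis=1, keepdims=True)
    # stabilized form: finite, equals [[0., -1000., -2000.]]
    stable = (X - m) - np.log(np.exp(X - m).sum(axis=1, keepdims=True))
    return naive, stable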
local_log_softmax = gof.PatternSub(in_pattern=(tensor.log, (softmax_op, 'x')),
                                   out_pattern=(make_out_pattern, 'x'),
                                   allow_multiple_clients=True)
# don't do register_stabilize, this is to make local_log_softmax run
# only after another more specific optimization that stabilizes cross entropy
# opt.register_stabilize(local_log_softmax, name = 'local_log_softmax')
opt.register_specialize(local_log_softmax, 'fast_compile_gpu', name='local_log_softmax')
def relu(x, alpha=0):
"""
Compute the element-wise rectified linear activation function.