提交 e8d9bed5，作者：Jeff Donahue

Add sigmoid_binary_crossentropy function and tests

上级 a4126bcc
...@@ -4,7 +4,8 @@ from .nnet import ( ...@@ -4,7 +4,8 @@ from .nnet import (
CrossentropySoftmax1HotWithBiasDx, CrossentropySoftmaxArgmax1HotWithBias, CrossentropySoftmax1HotWithBiasDx, CrossentropySoftmaxArgmax1HotWithBias,
LogSoftmax, Prepend_scalar_constant_to_each_row, LogSoftmax, Prepend_scalar_constant_to_each_row,
Prepend_scalar_to_each_row, Softmax, Prepend_scalar_to_each_row, Softmax,
SoftmaxGrad, SoftmaxWithBias, binary_crossentropy, SoftmaxGrad, SoftmaxWithBias,
binary_crossentropy, sigmoid_binary_crossentropy,
categorical_crossentropy, crossentropy_categorical_1hot, categorical_crossentropy, crossentropy_categorical_1hot,
crossentropy_categorical_1hot_grad, crossentropy_softmax_1hot, crossentropy_categorical_1hot_grad, crossentropy_softmax_1hot,
crossentropy_softmax_1hot_with_bias, crossentropy_softmax_1hot_with_bias,
......
...@@ -2017,6 +2017,31 @@ def binary_crossentropy(output, target): ...@@ -2017,6 +2017,31 @@ def binary_crossentropy(output, target):
return -(target * tensor.log(output) + (1.0 - target) * tensor.log(1.0 - output)) return -(target * tensor.log(output) + (1.0 - target) * tensor.log(1.0 - output))
def sigmoid_binary_crossentropy(output, target):
    """
    Binary cross-entropy computed directly from pre-sigmoid activations.

    Parameters
    ----------
    output
        Real-valued activations (range (-inf, +inf)). The sigmoid that maps
        them into (0, 1) is applied implicitly by this function.
    target
        Target probabilities, assumed to lie in [0, 1].

    Returns
    -------
    The result of applying an inlined `theano.OpFromGraph` (named
    'sigmoid_binary_crossentropy') to `output` and `target`.

    Notes
    -----
    Mathematically equivalent to
    `binary_crossentropy(sigmoid(output), target)`, but meant to be more
    efficient and numerically stable.

    Assuming `softplus(x) = log(1 + exp(x))`, the cross-entropy reduces to
    `softplus(output) - output * target`; `softplus(output)` is rewritten
    here as `softplus(-abs(output)) + output * (output > 0)`.

    """
    def _loss_grad(inputs, out_grads):
        # Gradient override handed to OpFromGraph: one gradient expression
        # per input, each scaled by the incoming output gradient.
        logits, labels = inputs
        (upstream,) = out_grads
        d_logits = upstream * (sigmoid(logits) - labels)
        d_labels = upstream * (-logits)
        return [d_logits, d_labels]

    # softplus is only evaluated on -|output|, never on output itself.
    log_term = softplus(-abs(output))
    # Folds the max(output, 0) correction together with -output * target.
    linear_term = output * ((output > 0) - target)
    loss = log_term + linear_term

    graph_inputs = [output, target]
    op = theano.OpFromGraph(
        graph_inputs,
        [loss],
        grad_overrides=_loss_grad,
        inline=True,
        name='sigmoid_binary_crossentropy',
    )
    return op(*graph_inputs)
def categorical_crossentropy(coding_dist, true_dist): def categorical_crossentropy(coding_dist, true_dist):
""" """
Return the cross-entropy between an approximating distribution and a true Return the cross-entropy between an approximating distribution and a true
......
...@@ -33,6 +33,7 @@ from theano.tensor.nnet import (categorical_crossentropy, ...@@ -33,6 +33,7 @@ from theano.tensor.nnet import (categorical_crossentropy,
h_softmax, h_softmax,
elu, elu,
binary_crossentropy, binary_crossentropy,
sigmoid_binary_crossentropy,
confusion_matrix) confusion_matrix)
from theano.tensor import matrix, vector, lvector, scalar from theano.tensor import matrix, vector, lvector, scalar
from theano.tensor.nnet.nnet import softsign from theano.tensor.nnet.nnet import softsign
...@@ -1771,6 +1772,36 @@ SoftsignTester = makeBroadcastTester( ...@@ -1771,6 +1772,36 @@ SoftsignTester = makeBroadcastTester(
) )
class T_sigmoid_binary_crossentropy(unittest.TestCase):
    """Value and gradient checks for `sigmoid_binary_crossentropy`."""

    def setUp(self):
        # Presumably reseeds the RNG so the random inputs are reproducible
        # (behavior of utt.seed_rng is defined outside this file).
        utt.seed_rng()

    def _get_test_inputs(self, n=50):
        """
        Draw `n` random predictions and `n` random targets.

        Returns `[pred, target]`: `pred` is left as raw normal samples,
        while `target` is passed through a sigmoid.
        """
        samples = numpy.random.randn(2, n).astype(config.floatX)
        raw_pred, raw_target = samples
        # apply sigmoid to target, but not pred
        squashed_target = 1 / (1 + numpy.exp(-raw_target))
        return [raw_pred, squashed_target]

    def test_matches_binary_crossentropy(self):
        """
        Check that sigmoid_binary_crossentropy(p, t) agrees numerically with
        binary_crossentropy(sigmoid(p), t).
        """
        sym_inputs = tensor.vectors('pt')
        sym_pred, sym_target = sym_inputs

        # Reference: explicit sigmoid followed by the plain cross-entropy.
        expected_expr = binary_crossentropy(sigmoid(sym_pred), sym_target)
        compute_expected = theano.function(sym_inputs, expected_expr)

        # Function under test.
        actual_expr = sigmoid_binary_crossentropy(sym_pred, sym_target)
        compute_actual = theano.function(sym_inputs, actual_expr)

        values = self._get_test_inputs()
        utt.assert_allclose(compute_expected(*values),
                            compute_actual(*values))

    def test_grad(self):
        # Hands the function and concrete inputs to utt.verify_grad.
        values = self._get_test_inputs()
        utt.verify_grad(sigmoid_binary_crossentropy, values)
def test_confusion_matrix(): def test_confusion_matrix():
# Defining numpy implementation of confusion matrix # Defining numpy implementation of confusion matrix
def numpy_conf_mat(actual, pred): def numpy_conf_mat(actual, pred):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论