提交 122f8d1e authored 作者: James Bergstra's avatar James Bergstra

fixed nnet.categorical_crossentropy

上级 27a301bb
......@@ -908,7 +908,7 @@ def binary_crossentropy(output, target):
"""
return -(target * tensor.log(output) + (1.0 - target) * tensor.log(1.0 - output))
def categorical_crossentropy(coding_dist, true_dist, axis=1):
def categorical_crossentropy(coding_dist, true_dist):
"""
WARNING: THIS FUNCTION IS UNNECESSARILY POLYMORPHIC.
We ultimately don't want the polymorphism, and will move this function to pylearn.algorithms.cost.
......@@ -939,23 +939,16 @@ def categorical_crossentropy(coding_dist, true_dist, axis=1):
:param axis: the dimension over which each distribution runs. (1 for row distributions, 0
for column distributions)
:rtype: dvector
:rtype: tensor of rank one-less-than `coding_dist`
:returns: the cross entropy between each coding and true distribution.
"""
assert true_dist.ndim in (1,2)
if true_dist.ndim == 2:
return -theano.sum(true_dist * log(coding_dist), axis=axis)
if true_dist.ndim == coding_dist.ndim:
return -theano.sum(true_dist * log(coding_dist), axis=coding_dist.ndim-1)
elif true_dist.ndim == coding_dist.ndim - 1:
return crossentropy_categorical_1hot(coding_dist, true_dist)
else:
if axis == 0:
retval = coding_dist.T
else:
retval = coding_dist,
return categorical_crossentropy_1hot(
#backport
#coding_dist.T if axis == 0 else coding_dist,
retval,
true_dist)
raise TypeError('rank mismatch between coding and true distributions')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论