提交 122f8d1e authored 作者: James Bergstra's avatar James Bergstra

fixed nnet.categorical_crossentropy

上级 27a301bb
...@@ -908,7 +908,7 @@ def binary_crossentropy(output, target): ...@@ -908,7 +908,7 @@ def binary_crossentropy(output, target):
""" """
return -(target * tensor.log(output) + (1.0 - target) * tensor.log(1.0 - output)) return -(target * tensor.log(output) + (1.0 - target) * tensor.log(1.0 - output))
def categorical_crossentropy(coding_dist, true_dist, axis=1): def categorical_crossentropy(coding_dist, true_dist):
""" """
WARNING: THIS FUNCTION IS UNNECESSARILY POLYMORPHIC. WARNING: THIS FUNCTION IS UNNECESSARILY POLYMORPHIC.
We ultimately don't want the polymorphism, and will move this function to pylearn.algorithms.cost. We ultimately don't want the polymorphism, and will move this function to pylearn.algorithms.cost.
...@@ -939,23 +939,16 @@ def categorical_crossentropy(coding_dist, true_dist, axis=1): ...@@ -939,23 +939,16 @@ def categorical_crossentropy(coding_dist, true_dist, axis=1):
:param axis: the dimension over which each distribution runs. (1 for row distributions, 0 :param axis: the dimension over which each distribution runs. (1 for row distributions, 0
for column distributions) for column distributions)
:rtype: dvector :rtype: tensor of rank one-less-than `coding_dist`
:returns: the cross entropy between each coding and true distribution. :returns: the cross entropy between each coding and true distribution.
""" """
assert true_dist.ndim in (1,2) if true_dist.ndim == coding_dist.ndim:
if true_dist.ndim == 2: return -theano.sum(true_dist * log(coding_dist), axis=coding_dist.ndim-1)
return -theano.sum(true_dist * log(coding_dist), axis=axis) elif true_dist.ndim == coding_dist.ndim - 1:
else: return crossentropy_categorical_1hot(coding_dist, true_dist)
if axis == 0:
retval = coding_dist.T
else: else:
retval = coding_dist, raise TypeError('rank mismatch between coding and true distributions')
return categorical_crossentropy_1hot(
#backport
#coding_dist.T if axis == 0 else coding_dist,
retval,
true_dist)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论