提交 ea77e606 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix softmax opt to not restrict itself to {float32, float64}.

上级 b8bd441f
...@@ -1746,7 +1746,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1746,7 +1746,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# Check z is zeros_like(log(sm)) # Check z is zeros_like(log(sm))
if not _is_const(z, 0): if not _is_const(z, 0):
return return
if z.type not in (dmatrix, fmatrix): if z.broadcastable != (False, False):
if not (vector_softmax and z.broadcastable == (True, False)): if not (vector_softmax and z.broadcastable == (True, False)):
return return
# here we know that we are incrementing a matrix of zeros # here we know that we are incrementing a matrix of zeros
...@@ -1758,14 +1758,15 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1758,14 +1758,15 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
if incr.ndim != 1 or incr.dtype not in tensor.float_dtypes: if incr.ndim != 1 or incr.dtype not in tensor.float_dtypes:
return return
# here we know that we are incrementing some part of matrix z by a vector # here we know that we are incrementing some part of
# matrix z by a vector
# unless the user has taken care to mark that the data and labels have the # unless the user has taken care to mark that the data and
# same number of rows, we cannot be sure here that # labels have the same number of rows, we cannot be sure
# len(y) == len(z) # here that len(y) == len(z) However, in the common case
# However, in the common case that these are predictions and labels it is true. # that these are predictions and labels it is true. We
# We leave it to the Op to crash (and the user to complain) if this assumption is # leave it to the Op to crash (and the user to complain)
# ever not true. # if this assumption is ever not true.
out_grad = -incr out_grad = -incr
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论