提交 182b798d authored 作者: Frederic's avatar Frederic

make CrossentropySoftmaxArgmax1HotWithBias accept uint* dtype.

上级 f13bed43
...@@ -752,8 +752,9 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -752,8 +752,9 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
or x.type.dtype not in ['float32', 'float64']: or x.type.dtype not in ['float32', 'float64']:
raise ValueError('b must be 1-d tensor of floats', b.type) raise ValueError('b must be 1-d tensor of floats', b.type)
if y_idx.type.ndim != 1 \ if y_idx.type.ndim != 1 \
or y_idx.type.dtype not in ['int8', 'int16', 'int32', 'int64']: or y_idx.type.dtype not in ['int8', 'int16', 'int32', 'int64',
raise ValueError('y_idx must be 1-d tensor of ints', y_idx.type) 'uint8', 'uint16', 'uint32', 'uint64']:
raise ValueError('y_idx must be 1-d tensor of [u]ints', y_idx.type)
# TODO: Is this correct? It used to be y, not y_idx # TODO: Is this correct? It used to be y, not y_idx
nll = tensor.TensorType(x.type.dtype, nll = tensor.TensorType(x.type.dtype,
...@@ -887,10 +888,14 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -887,10 +888,14 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
if ((PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT64) if ((PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT64)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT32) && (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT32)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT16) && (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT16)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT8)) && (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT8)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT64)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT32)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT16)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT8))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"y_idx not int8, int16, int32, or int64"); "y_idx not [u]int8, [u]int16, [u]int32, or [u]int64");
%(fail)s; %(fail)s;
} }
if (PyArray_DIMS(%(x)s)[0] != PyArray_DIMS(%(y_idx)s)[0]) if (PyArray_DIMS(%(x)s)[0] != PyArray_DIMS(%(y_idx)s)[0])
......
...@@ -247,11 +247,15 @@ class T_CrossentropySoftmaxArgmax1HotWithBias(utt.InferShapeTester): ...@@ -247,11 +247,15 @@ class T_CrossentropySoftmaxArgmax1HotWithBias(utt.InferShapeTester):
n_samples = 3 n_samples = 3
# First test gradient when getting a gradient on the NLL output. # First test gradient when getting a gradient on the NLL output.
def grad_on_nll(x, b): def grad_on_nll_dtype(dtype):
return self.op(x, b, y_idx=numpy.random.randint( def grad_on_nll(x, b):
low=0, high=n_classes, size=n_samples))[0] y_idx = numpy.random.randint(low=0, high=n_classes, size=n_samples).astype(dtype)
utt.verify_grad(grad_on_nll, [numpy.random.rand(n_samples, n_classes), return self.op(x, b, y_idx=y_idx)[0]
numpy.random.rand(n_classes)]) return grad_on_nll
for dtype in ['uint8', 'int8', 'uint64', 'int64']:
utt.verify_grad(grad_on_nll_dtype(dtype),
[numpy.random.rand(n_samples, n_classes),
numpy.random.rand(n_classes)])
# Then test gradient when getting a gradient on the softmax output. # Then test gradient when getting a gradient on the softmax output.
def grad_on_softmax(x, b): def grad_on_softmax(x, b):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论