提交 f13bed43 authored 作者: Frederic's avatar Frederic

make CrossentropySoftmax1HotWithBiasDx support uint* as class dtype.

上级 9fc78f6a
......@@ -1012,7 +1012,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
return [g_dy, g_sm, g_y_idx]
def c_code_cache_version(self):
return (2,)
return (3,)
def c_code(self, node, name, inp, out, sub):
dnll, sm, y_idx = inp
......@@ -1037,10 +1037,14 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
if ((PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT64)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT32)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT16)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT8))
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT8)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT64)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT32)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT16)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT8))
{
PyErr_SetString(PyExc_TypeError,
"y_idx not int8, int16, int32, or int64");
"y_idx not [u]int8, [u]int16, [u]int32, or [u]int64");
%(fail)s;
}
if ((PyArray_NDIM(%(dnll)s) != 1)
......
......@@ -194,16 +194,20 @@ class T_CrossentropySoftmax1Hot(unittest.TestCase):
class T_CrossentropySoftmax1HotWithBiasDx(utt.InferShapeTester):
def test0(self):
def ff(class_dtype):
def f(sm):
return (theano.tensor.nnet.crossentropy_softmax_1hot_with_bias_dx(
# Class indices
y = numpy.random.randint(low=0, high=5, size=10).astype(class_dtype)
return theano.tensor.nnet.crossentropy_softmax_1hot_with_bias_dx(
numpy.random.rand(10), # Gradient w.r.t. NLL.
sm, # Softmax output.
numpy.random.randint(low=0,
high=5, size=10))) # Class indices.
y)
return f
# Build a random softmax output whose rows sum to 1.
softmax_output = numpy.random.rand(10, 5)
softmax_output /= softmax_output.sum(axis=1).reshape(10, 1)
utt.verify_grad(f, [softmax_output])
for dtype in ['uint8', 'int8', 'uint64', 'int64']:
utt.verify_grad(ff(dtype), [softmax_output])
def test1(self):
rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论