提交 58ddeb48 authored 作者: Frederic's avatar Frederic

Update following code review.

上级 1f9e8f47
......@@ -20,7 +20,7 @@ since 2007. But it is also approachable enough to be used in the classroom
News
====
* New technical report on Theano: `Theano: new features and speed improvements <http://arxiv.org/abs/1211.5590>`_. Please cite the other paper bellow.
* New technical report on Theano: `Theano: new features and speed improvements <http://arxiv.org/abs/1211.5590>`_. Please cite the other paper below.
* Theano 0.6rc2 was released. Everybody is encouraged to update.
......
......@@ -752,8 +752,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
or x.type.dtype not in ['float32', 'float64']:
raise ValueError('b must be 1-d tensor of floats', b.type)
if y_idx.type.ndim != 1 \
or y_idx.type.dtype not in ['int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64']:
or y_idx.type.dtype not in tensor.discrete_dtypes:
raise ValueError('y_idx must be 1-d tensor of [u]ints', y_idx.type)
# TODO: Is this correct? It used to be y, not y_idx
......@@ -885,19 +884,6 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
PyErr_SetString(PyExc_ValueError, "y_idx not 1d tensor");
%(fail)s;
}
if ((PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT64)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT32)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT16)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT8)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT64)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT32)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT16)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT8))
{
PyErr_SetString(PyExc_TypeError,
"y_idx not [u]int8, [u]int16, [u]int32, or [u]int64");
%(fail)s;
}
if (PyArray_DIMS(%(x)s)[0] != PyArray_DIMS(%(y_idx)s)[0])
{
PyErr_Format(PyExc_ValueError,
......@@ -987,6 +973,15 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
dy = tensor.as_tensor_variable(dy)
sm = tensor.as_tensor_variable(sm)
y_idx = tensor.as_tensor_variable(y_idx)
if (dy.type.ndim != 1 or
dy.type.dtype not in ['float32', 'float64']):
raise ValueError('dy must be 1-d tensor of floats', dy.type)
if (sm.type.ndim != 2 or
sm.type.dtype not in ['float32', 'float64']):
raise ValueError('sm must be 1-d tensor of floats', sm.type)
if (y_idx.type.ndim != 1 or
y_idx.type.dtype not in tensor.discrete_dtypes):
raise ValueError('y_idx must be 1-d tensor of [u]ints', y_idx.type)
return Apply(self, [dy, sm, y_idx], [sm.type.make_variable()])
def perform(self, node, input_storage, output_storage):
......@@ -1039,19 +1034,6 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
"sm type should be float32 or float64");
%(fail)s;
}
if ((PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT64)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT32)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT16)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT8)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT64)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT32)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT16)
&& (PyArray_DESCR(%(y_idx)s)->type_num != NPY_UINT8))
{
PyErr_SetString(PyExc_TypeError,
"y_idx not [u]int8, [u]int16, [u]int32, or [u]int64");
%(fail)s;
}
if ((PyArray_NDIM(%(dnll)s) != 1)
|| (PyArray_NDIM(%(sm)s) != 2)
|| (PyArray_NDIM(%(y_idx)s) != 1))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论