提交 2b42af36 authored 作者: Frederic's avatar Frederic

Add theano.tensor.extra_ops.to_one_hot

上级 816d0bbc
...@@ -886,3 +886,23 @@ def fill_diagonal_offset(a, val, offset): ...@@ -886,3 +886,23 @@ def fill_diagonal_offset(a, val, offset):
is filled with scalar 'val'. The output is unwrapped. is filled with scalar 'val'. The output is unwrapped.
""" """
return fill_diagonal_offset_(a, val, offset) return fill_diagonal_offset_(a, val, offset)
def to_one_hot(y, nb_class, dtype=None):
"""Return a matrix where each row correspond to the one hot
encoding of each element in y.
:param y: A vector of integer value between 0 and nb_class - 1.
:param nb_class: The number of class in y.
:param dtype: The dtype of the returned matrix. Default floatX.
:return: A matrix of shape (y.shape[0], nb_class), where each
row ``i`` is the one hot encoding of the corresponding ``y[i]``
value.
"""
ret = theano.tensor.zeros((y.shape[0], nb_class),
dtype=dtype)
ret = theano.tensor.set_subtensor(ret[theano.tensor.arange(y.shape[0]), y],
1)
return ret
...@@ -7,9 +7,11 @@ from theano.tests import unittest_tools as utt ...@@ -7,9 +7,11 @@ from theano.tests import unittest_tools as utt
from theano.tensor.extra_ops import (CumsumOp, cumsum, CumprodOp, cumprod, from theano.tensor.extra_ops import (CumsumOp, cumsum, CumprodOp, cumprod,
BinCountOp, bincount, DiffOp, diff, BinCountOp, bincount, DiffOp, diff,
squeeze, RepeatOp, repeat, Bartlett, bartlett, squeeze, RepeatOp, repeat,
FillDiagonal, fill_diagonal, FillDiagonalOffset, Bartlett, bartlett,
fill_diagonal_offset) FillDiagonal, fill_diagonal,
FillDiagonalOffset, fill_diagonal_offset,
to_one_hot)
from theano import tensor as T from theano import tensor as T
from theano import config, tensor, function from theano import config, tensor, function
...@@ -529,3 +531,30 @@ class TestFillDiagonalOffset(utt.InferShapeTester): ...@@ -529,3 +531,30 @@ class TestFillDiagonalOffset(utt.InferShapeTester):
test_offset], test_offset],
self.op_class ) self.op_class )
def test_to_one_hot():
v = theano.tensor.ivector()
o = to_one_hot(v, 10)
f = theano.function([v], o)
out = f([1, 2, 3, 5, 6])
assert out.dtype == theano.config.floatX
assert numpy.allclose(
out,
[[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]])
v = theano.tensor.ivector()
o = to_one_hot(v, 10, dtype="int32")
f = theano.function([v], o)
out = f([1, 2, 3, 5, 6])
assert out.dtype == "int32"
assert numpy.allclose(
out,
[[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论