提交 e9c56c39 作者: Frédéric Bastien

Merge pull request #3931 from adbrebs/h_softmax_speedup

Speed up h_softmax when full output is requested.
...@@ -2293,14 +2293,8 @@ def h_softmax(x, batch_size, n_outputs, n_classes, n_outputs_per_class, ...@@ -2293,14 +2293,8 @@ def h_softmax(x, batch_size, n_outputs, n_classes, n_outputs_per_class,
if target is None: # Computes the probabilities of all the outputs if target is None: # Computes the probabilities of all the outputs
class_ids = tensor.tile(
tensor.arange(n_classes, dtype="int32")[None, :], (batch_size, 1))
# Second softmax that computes the output probabilities # Second softmax that computes the output probabilities
activations = sparse_block_dot( activations = tensor.tensordot(x, W2, (1, 1)) + b2
W2[None, :, :, :], x[:, None, :],
tensor.zeros((batch_size, 1), dtype='int32'), b2, class_ids)
output_probs = theano.tensor.nnet.softmax( output_probs = theano.tensor.nnet.softmax(
activations.reshape((-1, n_outputs_per_class))) activations.reshape((-1, n_outputs_per_class)))
output_probs = output_probs.reshape((batch_size, n_classes, -1)) output_probs = output_probs.reshape((batch_size, n_classes, -1))
......
...@@ -1614,6 +1614,14 @@ def test_h_softmax(): ...@@ -1614,6 +1614,14 @@ def test_h_softmax():
############# #############
x_mat = numpy.random.normal(size=(batch_size, input_size)).astype(floatX) x_mat = numpy.random.normal(size=(batch_size, input_size)).astype(floatX)
y_mat = numpy.random.randint(0, output_size, batch_size).astype('int32') y_mat = numpy.random.randint(0, output_size, batch_size).astype('int32')
assert(fun_output_tg(x_mat, y_mat).shape == (batch_size,)) tg_output = fun_output_tg(x_mat, y_mat)
assert(fun_output(x_mat).shape == (batch_size, output_size)) all_outputs = fun_output(x_mat)
assert(tg_output.shape == (batch_size,))
assert(all_outputs.shape == (batch_size, output_size))
# Verifies that the outputs computed by fun_output_tg are the same as those
# computed by fun_output.
utt.assert_allclose(
all_outputs[numpy.arange(0, batch_size), y_mat], tg_output)
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论