Add simple test to CTC wrapper tests

上级 e8b3fb57
......@@ -10,6 +10,13 @@ from theano.tensor.nnet.ctc import (ctc_enabled, ctc)
class TestCTC(unittest.TestCase):
"""
Test Baidu CTC wrapper implementation.
Expected values for costs and gradients are obtained through an external
C implementation, that uses the library directly.
"""
def setUp(self):
if not ctc_enabled:
self.skipTest('Optional library warp-ctc not available')
......@@ -66,3 +73,23 @@ class TestCTC(unittest.TestCase):
expected_gradients = np.asarray(grads, dtype=np.float32)
self.run_ctc(activations, labels, activation_times, expected_costs, expected_gradients)
def simple_test(self):
activations = np.asarray([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
[[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]],
dtype=np.float32)
activation_times = np.asarray([2, 2], dtype=np.int32)
labels = np.asarray([[1, 2], [1, 2]], dtype=np.int32)
expected_costs = np.asarray([2.962858438, 3.053659201], dtype=np.float32)
grads = [[[0.177031219, -0.7081246376, 0.177031219, 0.177031219, 0.177031219],
[0.177031219, -0.8229685426, 0.291875124, 0.177031219, 0.177031219]],
[[0.291875124, 0.177031219, -0.8229685426, 0.177031219, 0.177031219],
[0.1786672771, 0.1786672771, -0.7334594727, 0.1974578798, 0.1786672771]]]
expected_gradients = np.asarray(grads, dtype=np.float32)
self.run_ctc(activations, labels, activation_times, expected_costs, expected_gradients)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论