Add Torch test case to CTC GPU wrapper tests

Parent commit: 3be5ca7b
......@@ -36,6 +36,7 @@ class TestCTC(unittest.TestCase):
cost, = test()
cpu_cost = np.empty(shape=cost.shape, dtype=np.float32)
# Transfer costs from GPU memory to host
cost.read(cpu_cost)
#cpu_grad = np.empty(shape=grad.shape, dtype=np.float32)
......@@ -44,6 +45,38 @@ class TestCTC(unittest.TestCase):
#utt.assert_allclose(expected_grads, grad)
utt.assert_allclose(expected_costs, cpu_cost)
# Reference case taken from the Torch tutorial at:
# https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
def test_torch_case(self):
    """Check CTC costs/gradients against the Baidu warp-ctc Torch tutorial.

    Activation layout, from slowest to fastest changing dimension, is
    (time, batchSize, inputLayerSize).
    """
    acts = np.asarray(
        [[[0, 0, 0, 0, 0], [1, 2, 3, 4, 5], [-5, -4, -3, -2, -1]],
         [[0, 0, 0, 0, 0], [6, 7, 8, 9, 10], [-10, -9, -8, -7, -6]],
         [[0, 0, 0, 0, 0], [11, 12, 13, 14, 15], [-15, -14, -13, -12, -11]]],
        dtype=np.float32)
    # Number of time steps actually used by each sequence in the batch.
    seq_durations = np.asarray([1, 3, 3], dtype=np.int32)
    # Target labels per sequence.
    # NOTE(review): -1 presumably pads the shorter label sequence — confirm
    # against run_ctc's handling.
    label_matrix = np.asarray([[1, -1],
                               [3, 3],
                               [2, 3]], dtype=np.int32)
    # Expected per-sequence negative log-likelihoods from the tutorial.
    ref_costs = np.asarray([1.609437943, 7.355742931, 4.938849926],
                           dtype=np.float32)
    # Expected gradients w.r.t. the activations, same layout as `acts`.
    ref_grads = np.asarray(
        [[[0.2, -0.8, 0.2, 0.2, 0.2],
          [0.01165623125, 0.03168492019, 0.08612854034, -0.7658783197, 0.636408627],
          [-0.02115798369, 0.03168492019, -0.8810571432, 0.2341216654, 0.636408627]],
         [[0, 0, 0, 0, 0],
          [-0.9883437753, 0.03168492019, 0.08612854034, 0.2341216654, 0.636408627],
          [-0.02115798369, 0.03168492019, -0.1891518533, -0.4577836394, 0.636408627]],
         [[0, 0, 0, 0, 0],
          [0.01165623125, 0.03168492019, 0.08612854034, -0.7658783197, 0.636408627],
          [-0.02115798369, 0.03168492019, 0.08612854034, -0.7330639958, 0.636408627]]],
        dtype=np.float32)
    self.run_ctc(acts, label_matrix, seq_durations, ref_costs, ref_grads)
def simple_test(self):
activations = np.asarray([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
[[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]],
......
Markdown format
0%
You added 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment