Correct expected cost values in test_torch_case of test_ctc

上级 0b33b90d
...@@ -14,7 +14,7 @@ class TestCTC(unittest.TestCase): ...@@ -14,7 +14,7 @@ class TestCTC(unittest.TestCase):
if not ctc_enabled: if not ctc_enabled:
self.skipTest('Optional library warp-ctc not available') self.skipTest('Optional library warp-ctc not available')
def run_ctc(self, activations, labels, input_length, expected_costs): def run_ctc(self, activations, labels, input_length, expected_costs, expected_grads):
# Check if softmax probabilites are approximately equal to the gradients # Check if softmax probabilites are approximately equal to the gradients
# of the activations, using utt.assert_allclose(a, b) # of the activations, using utt.assert_allclose(a, b)
...@@ -31,9 +31,12 @@ class TestCTC(unittest.TestCase): ...@@ -31,9 +31,12 @@ class TestCTC(unittest.TestCase):
test = theano.function([], [t_cost]) test = theano.function([], [t_cost])
cost, grad = train() cost, grad = train()
cost, = test() test_cost, = test()
utt.assert_allclose(cost, expected_costs) print( grad )
#utt.assert_allclose(grad, expected_grads)
utt.assert_allclose( expected_costs, cost )
# Test obtained from Torch tutorial at: # Test obtained from Torch tutorial at:
# https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md # https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
...@@ -50,7 +53,10 @@ class TestCTC(unittest.TestCase): ...@@ -50,7 +53,10 @@ class TestCTC(unittest.TestCase):
[3, 3], [3, 3],
[2, 3]], dtype=np.int32) [2, 3]], dtype=np.int32)
expected_costs = np.asarray([3.03655, 7.35574, 4.93884], expected_costs = np.asarray([1.609437943, 7.355742931, 4.938849926],
dtype=np.float32) dtype=np.float32)
self.run_ctc(activations, labels, activation_times, expected_costs) grads = [0.200000003, -0.8000000119, 0.200000003, 0.200000003, 0.200000003]
expected_gradients = np.asarray(grads, dtype=np.float32)
self.run_ctc(activations, labels, activation_times, expected_costs, expected_gradients)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论