Correct expected cost values in test_torch_case of test_ctc

上级 0b33b90d
......@@ -14,7 +14,7 @@ class TestCTC(unittest.TestCase):
if not ctc_enabled:
self.skipTest('Optional library warp-ctc not available')
def run_ctc(self, activations, labels, input_length, expected_costs):
def run_ctc(self, activations, labels, input_length, expected_costs, expected_grads):
# Check if softmax probabilites are approximately equal to the gradients
# of the activations, using utt.assert_allclose(a, b)
......@@ -31,9 +31,12 @@ class TestCTC(unittest.TestCase):
test = theano.function([], [t_cost])
cost, grad = train()
cost, = test()
test_cost, = test()
utt.assert_allclose(cost, expected_costs)
print( grad )
#utt.assert_allclose(grad, expected_grads)
utt.assert_allclose( expected_costs, cost )
# Test obtained from Torch tutorial at:
# https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
......@@ -50,7 +53,10 @@ class TestCTC(unittest.TestCase):
[3, 3],
[2, 3]], dtype=np.int32)
expected_costs = np.asarray([3.03655, 7.35574, 4.93884],
expected_costs = np.asarray([1.609437943, 7.355742931, 4.938849926],
dtype=np.float32)
self.run_ctc(activations, labels, activation_times, expected_costs)
grads = [0.200000003, -0.8000000119, 0.200000003, 0.200000003, 0.200000003]
expected_gradients = np.asarray(grads, dtype=np.float32)
self.run_ctc(activations, labels, activation_times, expected_costs, expected_gradients)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论