Fix setup of symbolic gradients in CTC test

上级 5213ba52
......@@ -26,7 +26,7 @@ class TestCTC(unittest.TestCase):
# Compute CTC costs and gradients on the CPU to compare with GPU
cpu_ctc_cost = ctc(t_activations, t_labels, t_activation_times)
# Symbolic gradient of CTC cost
cpu_ctc_grad = T.grad(T.mean(t_cost), t_activations)
cpu_ctc_grad = T.grad(T.mean(cpu_ctc_cost), t_activations)
# Compile CPU function without optimization
cpu_train = theano.function([], [cpu_ctc_cost, cpu_ctc_grad], mode=mode_without_gpu)
......@@ -34,7 +34,7 @@ class TestCTC(unittest.TestCase):
gpu_ctc_cost = gpu_ctc(t_activations, t_labels, t_activation_times)
# Symbolic gradient of CTC cost
gpu_ctc_grad = T.grad(T.mean(t_cost), t_activations)
gpu_ctc_grad = T.grad(T.mean(gpu_ctc_cost), t_activations)
# Compile symbolic functions
gpu_train = theano.function([], [gpu_ctc_cost, gpu_ctc_grad])
......@@ -46,7 +46,7 @@ class TestCTC(unittest.TestCase):
grad_from_gpu = np.asarray(gpu_grad)
# Check that results are in conformance with expected values
utt.assert_allclose(expected_grads / gpu_ctc_cost.shape[0], grad_from_gpu)
utt.assert_allclose(expected_grads / cost_from_gpu.shape[0], grad_from_gpu)
utt.assert_allclose(expected_costs, cost_from_gpu)
# Compare values obtained from CPU and GPU implementations
......@@ -87,13 +87,13 @@ class TestCTC(unittest.TestCase):
def test_ctc(self):
activations = np.asarray([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
[[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]],
dtype='float32')
dtype=np.float32)
activation_times = np.asarray([2, 2], dtype='int32')
activation_times = np.asarray([2, 2], dtype=np.int32)
labels = np.asarray([[1, 2], [1, 2]], dtype='int32')
labels = np.asarray([[1, 2], [1, 2]], dtype=np.int32)
expected_costs = np.asarray([2.962858438, 3.053659201], dtype='float32')
expected_costs = np.asarray([2.962858438, 3.053659201], dtype=np.float32)
grads = [[[0.177031219, -0.7081246376, 0.177031219, 0.177031219, 0.177031219],
[0.177031219, -0.8229685426, 0.291875124, 0.177031219, 0.177031219]],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论