Add check to verify that GPU values are equal to CPU values in CTC tests

上级 ef01d10d
......@@ -8,6 +8,8 @@ import theano.tensor as T
from theano.tests import unittest_tools as utt
import theano.gpuarray
from theano.gpuarray.ctc import (ctc_enabled, gpu_ctc)
from theano.tensor.nnet.ctc import ctc
from .config import (mode_with_gpu, mode_without_gpu)
class TestCTC(unittest.TestCase):
......@@ -21,21 +23,35 @@ class TestCTC(unittest.TestCase):
t_activation_times = theano.shared(input_length, name="activation_times")
t_labels = theano.shared(labels, name="labels")
t_cost = gpu_ctc(t_activations, t_labels, t_activation_times)
# Compute CTC costs and gradients on the CPU to compare with GPU
cpu_ctc_cost = ctc(t_activations, t_labels, t_activation_times)
# Symbolic gradient of CTC cost
t_grad = T.grad(T.mean(t_cost), t_activations)
cpu_ctc_grad = T.grad(T.mean(t_cost), t_activations)
# Compile CPU function without optimization
cpu_train = theano.function([], [cpu_ctc_cost, cpu_ctc_grad], mode=mode_without_gpu)
cpu_cost, cpu_grad = cpu_train()
gpu_ctc_cost = gpu_ctc(t_activations, t_labels, t_activation_times)
# Symbolic gradient of CTC cost
gpu_ctc_grad = T.grad(T.mean(t_cost), t_activations)
# Compile symbolic functions
train = theano.function([], [t_cost, t_grad])
gpu_train = theano.function([], [gpu_ctc_cost, gpu_ctc_grad])
cost, grad = train()
gpu_cost, gpu_grad = gpu_train()
# Transfer costs from GPU memory to host
cpu_cost = np.asarray(cost)
cost_from_gpu = np.asarray(gpu_cost)
# Transfer gradients from GPU memory to host
cpu_grad = np.asarray(grad)
grad_from_gpu = np.asarray(gpu_grad)
# Check that results are in conformance with expected values
utt.assert_allclose(expected_grads / gpu_ctc_cost.shape[0], grad_from_gpu)
utt.assert_allclose(expected_costs, cost_from_gpu)
utt.assert_allclose(expected_grads / cost.shape[0], cpu_grad)
utt.assert_allclose(expected_costs, cpu_cost)
# Compare values obtained from CPU and GPU implementations
utt.assert_allclose(cpu_cost, cost_from_gpu)
utt.assert_allclose(cpu_grad, grad_from_gpu)
# Test obtained from Torch tutorial at:
# https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论