Enable gradients tests in CTC GPU wrapper tests

上级 b0c0476c
......@@ -5,11 +5,11 @@ import numpy as np
import theano
import theano.tensor as T
from theano import config
from theano.tests import unittest_tools as utt
import theano.gpuarray
from theano.gpuarray.basic_ops import infer_context_name
from theano.gpuarray.ctc import (ctc_enabled, ctc)
from theano.gpuarray.basic_ops import gpu_contiguous
class TestCTC(unittest.TestCase):
def setUp(self):
......@@ -17,9 +17,6 @@ class TestCTC(unittest.TestCase):
self.skipTest('Optional library warp-ctc not available')
def run_ctc(self, activations, labels, input_length, expected_costs, expected_grads):
# Check if softmax probabilites are approximately equal to the gradients
# of the activations, using utt.assert_allclose(a, b)
# Create symbolic variables
t_activations = theano.shared(activations, name="activations")
t_activation_times = theano.shared(input_length, name="activation_times")
......@@ -29,20 +26,23 @@ class TestCTC(unittest.TestCase):
# Symbolic gradient of CTC cost
t_grad = T.grad(T.mean(t_cost), t_activations)
# Compile symbolic functions
#train = theano.function([], [t_cost, t_grad])
train = theano.function([], [t_cost, t_grad])
test = theano.function([], [t_cost])
#cost, grad = train()
cost, grad = train()
cost, = test()
cpu_cost = np.empty(shape=cost.shape, dtype=np.float32)
# Transfer costs from GPU memory to host
cost.read(cpu_cost)
cost.sync()
#cpu_grad = np.empty(shape=grad.shape, dtype=np.float32)
#grad.read(cpu_grad)
cpu_grad = np.empty(shape=grad.shape, dtype=np.float32)
# Transfer gradients from GPU memory to host
grad.read(cpu_grad)
grad.sync()
#utt.assert_allclose(expected_grads, grad)
utt.assert_allclose(expected_grads / cost.shape[0], cpu_grad)
utt.assert_allclose(expected_costs, cpu_cost)
# Test obtained from Torch tutorial at:
......@@ -76,8 +76,7 @@ class TestCTC(unittest.TestCase):
self.run_ctc(activations, labels, activation_times, expected_costs, expected_gradients)
def simple_test(self):
def test_ctc(self):
activations = np.asarray([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
[[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]],
dtype='float32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论