Add test to check if gradient computations are disabled in CTC test

上级 7c44bf4c
......@@ -6,7 +6,7 @@ import numpy as np
import theano
import theano.tensor as T
from theano.tests import unittest_tools as utt
from theano.tensor.nnet.ctc import (ctc_enabled, ctc)
from theano.tensor.nnet.ctc import (ctc_enabled, ctc, ConnectionistTemporalClassification)
class TestCTC(unittest.TestCase):
......@@ -38,6 +38,18 @@ class TestCTC(unittest.TestCase):
utt.assert_allclose(expected_grads / cost.shape[0], grad)
utt.assert_allclose(expected_costs, cost)
self.check_grads_disabled(t_activations, t_labels, t_activation_times)
def check_grads_disabled(self, activations, labels, input_length):
"""
Check if optimization to disable gradients is working
"""
ctc_cost = ctc(activations, labels, input_length)
ctc_function = theano.function([], [ctc_cost])
for node in ctc_function.maker.fgraph.apply_nodes:
if isinstance(node.op, ConnectionistTemporalClassification):
assert (node.op.compute_grad is False)
# Test obtained from Torch tutorial at:
# https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
def test_torch_case(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论