提交 fc12422d authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix CTC tests in FAST_COMPILE

上级 3e86efec
......@@ -163,7 +163,7 @@ def gpu_ctc(activations, labels, input_lengths):
# Disable gradient computation if not needed
@register_canonicalize
@register_canonicalize("fast_compile")
@local_optimizer([GpuConnectionistTemporalClassification])
def local_gpu_ctc_no_grad(node):
if isinstance(node.op, GpuConnectionistTemporalClassification):
......
......@@ -49,7 +49,7 @@ class TestCTC(unittest.TestCase):
# Symbolic gradient of CTC cost
gpu_ctc_grad = T.grad(T.mean(gpu_ctc_cost), activations)
outputs += [gpu_ctc_grad]
return theano.function([], outputs)
return theano.function([], outputs, mode=mode_with_gpu)
def check_expected_values(self, activations, labels, input_length, expected_costs, expected_grads):
gpu_train = self.setup_gpu_op(activations, labels, input_length)
......@@ -139,4 +139,4 @@ class TestCTC(unittest.TestCase):
ctc_op = ctc_op_functor(labels, activation_times)
utt.verify_grad(ctc_op, [activations])
utt.verify_grad(ctc_op, [activations], mode=mode_with_gpu)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论