Remove CTC enabled check from CTC wrapper L_op method

上级 9eba634a
...@@ -126,11 +126,6 @@ class GpuConnectionistTemporalClassification(gof.COp): ...@@ -126,11 +126,6 @@ class GpuConnectionistTemporalClassification(gof.COp):
outputs=outputs) outputs=outputs)
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
if not ctc_enabled:
raise RuntimeError('Baidu CTC is not enabled and '
'GpuConnectionistTemporalClassification Op '
'can not be constructed.')
assert len(outputs) == 2
# Gradients computed by Op # Gradients computed by Op
gradients = outputs[1] gradients = outputs[1]
# Gradients of original function, to compose chain rule # Gradients of original function, to compose chain rule
......
...@@ -115,10 +115,6 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): ...@@ -115,10 +115,6 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
outputs=outputs) outputs=outputs)
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
if not ctc_enabled:
raise RuntimeError('Baidu CTC is not enabled and '
'ConnectionistTemporalClassification Op '
'can not be constructed.')
gradients = outputs[1] gradients = outputs[1]
grad_op = output_grads[0] grad_op = output_grads[0]
total_grad = T.basic.batched_dot(grad_op, gradients.dimshuffle(1, 0, 2)).dimshuffle(1, 0, 2) total_grad = T.basic.batched_dot(grad_op, gradients.dimshuffle(1, 0, 2)).dimshuffle(1, 0, 2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论