Use outputs list to set CTC Op's outputs in make_node

上级 d8c66717
......@@ -109,15 +109,16 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
raise TypeError('Label lengths must use the int32 type!')
costs = T.fvector(name="ctc_cost")
outputs = [costs]
if self.compute_grad:
gradients = T.ftensor3(name="ctc_grad")
outputs += [gradients]
# Return only the cost. Gradient will be returned by grad()
self.default_output = 0
return gof.Apply(self, inputs=[t_activations, t_labels, t_input_lengths],
outputs=[costs, gradients])
outputs=outputs)
def L_op(self, inputs, outputs, output_grads):
if not ctc_enabled:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论