Change CPU CTC Op to use _cop_num_outputs

上级 e2b07918
...@@ -30,23 +30,27 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): ...@@ -30,23 +30,27 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
""" """
__props__ = ('compute_grad',) __props__ = ('compute_grad',)
_cop_num_inputs = 3
_cop_num_outputs = 2
func_file = "./ctc_wrapper.c" func_file = "./ctc_wrapper.c"
func_name = "APPLY_SPECIFIC(ctc_cost_cpu)" func_name = "APPLY_SPECIFIC(ctc_cost_cpu)"
def __init__(self, compute_grad=True): def __init__(self, compute_grad=True):
if not compute_grad: if not ctc_enabled:
self.func_name = "APPLY_SPECIFIC(ctc_cost_cpu_no_grad)" raise RuntimeError('Baidu CTC is not enabled and '
'ConnectionistTemporalClassification Op '
'can not be constructed.')
elif config.ctc.root == "":
raise ValueError('ctc.root variable is not set, please set it '
'to the root directory of the CTC library in '
'your system.')
gof.COp.__init__(self, self.func_file, self.func_name) gof.COp.__init__(self, self.func_file, self.func_name)
gof.OpenMPOp.__init__(self) gof.OpenMPOp.__init__(self)
self.compute_grad = compute_grad self.compute_grad = compute_grad
if config.ctc.root == "":
raise ValueError('ctc.root variable is not set, please set it '
'to the root directory of the CTC library in '
'your system.')
def c_lib_dirs(self): def c_lib_dirs(self):
dirs = [] dirs = []
if ctc_enabled: if ctc_enabled:
...@@ -83,10 +87,6 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): ...@@ -83,10 +87,6 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
return ["ctc.h"] + gof.OpenMPOp.c_headers(self) return ["ctc.h"] + gof.OpenMPOp.c_headers(self)
def make_node(self, activations, labels, input_lengths): def make_node(self, activations, labels, input_lengths):
if not ctc_enabled:
raise RuntimeError('Baidu CTC is not enabled and '
'ConnectionistTemporalClassification Op '
'can not be constructed.')
t_activations = T.as_tensor_variable(activations) t_activations = T.as_tensor_variable(activations)
# Ensure activations array is C-contiguous # Ensure activations array is C-contiguous
t_activations = cpu_contiguous(t_activations) t_activations = cpu_contiguous(t_activations)
......
...@@ -250,18 +250,3 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations, ...@@ -250,18 +250,3 @@ int APPLY_SPECIFIC(ctc_cost_cpu)(PyArrayObject * in_activations,
return 0; return 0;
} }
/**
* Wrapper version with gradient computation disabled.
*/
int APPLY_SPECIFIC(ctc_cost_cpu_no_grad)(PyArrayObject * in_activations,
PyArrayObject * in_labels,
PyArrayObject * in_input_lengths,
PyArrayObject ** out_costs)
{
return APPLY_SPECIFIC(ctc_cost_cpu)(in_activations,
in_labels,
in_input_lengths,
out_costs,
NULL);
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论