Add openmp flag to CTC Op constructor

上级 7e117325
...@@ -108,14 +108,14 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp): ...@@ -108,14 +108,14 @@ class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
func_file = "./ctc_wrapper.c" func_file = "./ctc_wrapper.c"
func_name = "APPLY_SPECIFIC(ctc_cost_cpu)" func_name = "APPLY_SPECIFIC(ctc_cost_cpu)"
def __init__(self, compute_grad=True): def __init__(self, compute_grad=True, openmp=None):
if not ctc_available(): if not ctc_available():
raise RuntimeError('Baidu CTC is not available and ' raise RuntimeError('Baidu CTC is not available and '
'ConnectionistTemporalClassification Op ' 'ConnectionistTemporalClassification Op '
'can not be constructed.') 'can not be constructed.')
gof.COp.__init__(self, self.func_file, self.func_name) gof.COp.__init__(self, self.func_file, self.func_name)
gof.OpenMPOp.__init__(self) gof.OpenMPOp.__init__(self, openmp=openmp)
self.compute_grad = compute_grad self.compute_grad = compute_grad
# Return only the cost. Gradient will be returned by grad() # Return only the cost. Gradient will be returned by grad()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论