Add OpenMP support to CTC wrapper

上级 595143f7
......@@ -12,7 +12,7 @@ from theano.gradient import grad_undefined
ctc_enabled = config.ctc.enabled
class ConnectionistTemporalClassification(gof.COp):
class ConnectionistTemporalClassification(gof.COp, gof.OpenMPOp):
__props__ = ('compute_grad',)
func_file = "./ctc_wrapper.c"
......@@ -22,8 +22,8 @@ class ConnectionistTemporalClassification(gof.COp):
if not compute_grad:
self.func_name = "APPLY_SPECIFIC(ctc_cost_cpu_no_grad)"
super(ConnectionistTemporalClassification, self).__init__(self.func_file,
self.func_name)
gof.COp.__init__(self, self.func_file, self.func_name)
gof.OpenMPOp.__init__(self)
self.compute_grad = compute_grad
self.costs = T.fvector(name="ctc_cost")
......@@ -54,8 +54,11 @@ class ConnectionistTemporalClassification(gof.COp):
dirs.append(os.path.join(config.ctc.root, "include"))
return dirs
def c_compile_args(self):
return gof.OpenMPOp.c_compile_args(self)
def c_headers(self):
return ["ctc.h"]
return ["ctc.h"] + gof.OpenMPOp.c_headers(self)
def make_node(self, activations, labels, input_lengths):
if not ctc_enabled:
......
......@@ -13,8 +13,11 @@ void ctc_context_init(ctc_context_t * context)
struct ctcOptions * options = &(context->options);
memset(options, 0, sizeof(struct ctcOptions));
options->loc = CTC_CPU;
#if defined(_OPENMP)
options->num_threads = omp_get_num_threads();
#else
options->num_threads = 1;
#endif
context->workspace = NULL;
context->input_lengths = NULL;
context->flat_labels = NULL;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论