Fix CTC gradient disabling optimization

* Optimization returns new outputs instead of modifying the graph node in place
* Add default_output as class property

Signed-off-by: João Victor Tozatti Risso <joaovictor.risso@gmail.com>
上级 416a7299
......@@ -28,7 +28,7 @@ class GpuConnectionistTemporalClassification(gof.COp):
compute_grad
If set to True, enables the computation of gradients of the CTC loss function.
"""
__props__ = ('compute_grad',)
__props__ = ('compute_grad', 'default_output',)
_cop_num_inputs = 3
_cop_num_outputs = 2
......@@ -39,15 +39,21 @@ class GpuConnectionistTemporalClassification(gof.COp):
params_type = gpu_context_type
def __init__(self, compute_grad=True):
self.compute_grad = compute_grad
gof.COp.__init__(self, self.func_file, self.func_name)
if config.ctc.root == "":
if not ctc_enabled:
raise RuntimeError('Baidu CTC is not enabled and '
'GpuConnectionistTemporalClassification Op '
'can not be constructed.')
elif config.ctc.root == "":
raise ValueError('ctc.root variable is not set, please set it '
'to the root directory of the CTC library in '
'your system.')
self.compute_grad = compute_grad
# Return only the cost. Gradient will be returned by grad()
self.default_output = 0
gof.COp.__init__(self, self.func_file, self.func_name)
def c_lib_dirs(self):
dirs = []
if ctc_enabled:
......@@ -86,13 +92,7 @@ class GpuConnectionistTemporalClassification(gof.COp):
return node.inputs[0].type.context
def make_node(self, activations, labels, input_lengths):
if not ctc_enabled:
raise RuntimeError('Baidu CTC is not enabled and '
'GpuConnectionistTemporalClassification Op '
'can not be constructed.')
context_name = infer_context_name(activations)
t_activations = as_gpuarray_variable(activations,
context_name=context_name)
# Ensure activations array is C-contiguous
......@@ -111,9 +111,6 @@ class GpuConnectionistTemporalClassification(gof.COp):
if t_input_lengths.type.dtype != 'int32':
raise TypeError('Label lengths must use the int32 type!')
# Return only the cost. Gradient will be returned by grad()
self.default_output = 0
costs = GpuArrayType(dtype='float32',
broadcastable=(False,),
context_name=context_name)()
......@@ -133,6 +130,7 @@ class GpuConnectionistTemporalClassification(gof.COp):
raise RuntimeError('Baidu CTC is not enabled and '
'GpuConnectionistTemporalClassification Op '
'can not be constructed.')
assert len(outputs) == 2
# Gradients computed by Op
gradients = outputs[1]
# Gradients of original function, to compose chain rule
......@@ -171,8 +169,7 @@ def gpu_ctc(activations, labels, input_lengths):
1-D array
Cost of each example in the minibatch.
"""
return GpuConnectionistTemporalClassification()(activations, labels,
input_lengths)
return GpuConnectionistTemporalClassification()(activations, labels, input_lengths)
# Disable gradient computation if not needed
......@@ -183,5 +180,5 @@ def local_gpu_ctc_no_grad(node):
if isinstance(node.op, GpuConnectionistTemporalClassification):
if len(node.outputs) > 1:
if len(node.outputs[1].clients) == 0: # gradient is not used
node.op = GpuConnectionistTemporalClassification(compute_grad=False)
node.outputs = node.outputs[:1] # costs only
node.op.compute_grad = False
return [GpuConnectionistTemporalClassification(compute_grad=False)(*node.inputs)]
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论