Fix CTC gradient disabling optimization

* Optimization returns new outputs instead of modifying the graph node in place
* Add default_output as class property

Signed-off-by: João Victor Tozatti Risso <joaovictor.risso@gmail.com>
上级 416a7299
...@@ -28,7 +28,7 @@ class GpuConnectionistTemporalClassification(gof.COp): ...@@ -28,7 +28,7 @@ class GpuConnectionistTemporalClassification(gof.COp):
compute_grad compute_grad
If set to True, enables the computation of gradients of the CTC loss function. If set to True, enables the computation of gradients of the CTC loss function.
""" """
__props__ = ('compute_grad',) __props__ = ('compute_grad', 'default_output',)
_cop_num_inputs = 3 _cop_num_inputs = 3
_cop_num_outputs = 2 _cop_num_outputs = 2
...@@ -39,15 +39,21 @@ class GpuConnectionistTemporalClassification(gof.COp): ...@@ -39,15 +39,21 @@ class GpuConnectionistTemporalClassification(gof.COp):
params_type = gpu_context_type params_type = gpu_context_type
def __init__(self, compute_grad=True): def __init__(self, compute_grad=True):
self.compute_grad = compute_grad if not ctc_enabled:
raise RuntimeError('Baidu CTC is not enabled and '
gof.COp.__init__(self, self.func_file, self.func_name) 'GpuConnectionistTemporalClassification Op '
'can not be constructed.')
if config.ctc.root == "": elif config.ctc.root == "":
raise ValueError('ctc.root variable is not set, please set it ' raise ValueError('ctc.root variable is not set, please set it '
'to the root directory of the CTC library in ' 'to the root directory of the CTC library in '
'your system.') 'your system.')
self.compute_grad = compute_grad
# Return only the cost. Gradient will be returned by grad()
self.default_output = 0
gof.COp.__init__(self, self.func_file, self.func_name)
def c_lib_dirs(self): def c_lib_dirs(self):
dirs = [] dirs = []
if ctc_enabled: if ctc_enabled:
...@@ -86,13 +92,7 @@ class GpuConnectionistTemporalClassification(gof.COp): ...@@ -86,13 +92,7 @@ class GpuConnectionistTemporalClassification(gof.COp):
return node.inputs[0].type.context return node.inputs[0].type.context
def make_node(self, activations, labels, input_lengths): def make_node(self, activations, labels, input_lengths):
if not ctc_enabled:
raise RuntimeError('Baidu CTC is not enabled and '
'GpuConnectionistTemporalClassification Op '
'can not be constructed.')
context_name = infer_context_name(activations) context_name = infer_context_name(activations)
t_activations = as_gpuarray_variable(activations, t_activations = as_gpuarray_variable(activations,
context_name=context_name) context_name=context_name)
# Ensure activations array is C-contiguous # Ensure activations array is C-contiguous
...@@ -111,9 +111,6 @@ class GpuConnectionistTemporalClassification(gof.COp): ...@@ -111,9 +111,6 @@ class GpuConnectionistTemporalClassification(gof.COp):
if t_input_lengths.type.dtype != 'int32': if t_input_lengths.type.dtype != 'int32':
raise TypeError('Label lengths must use the int32 type!') raise TypeError('Label lengths must use the int32 type!')
# Return only the cost. Gradient will be returned by grad()
self.default_output = 0
costs = GpuArrayType(dtype='float32', costs = GpuArrayType(dtype='float32',
broadcastable=(False,), broadcastable=(False,),
context_name=context_name)() context_name=context_name)()
...@@ -133,6 +130,7 @@ class GpuConnectionistTemporalClassification(gof.COp): ...@@ -133,6 +130,7 @@ class GpuConnectionistTemporalClassification(gof.COp):
raise RuntimeError('Baidu CTC is not enabled and ' raise RuntimeError('Baidu CTC is not enabled and '
'GpuConnectionistTemporalClassification Op ' 'GpuConnectionistTemporalClassification Op '
'can not be constructed.') 'can not be constructed.')
assert len(outputs) == 2
# Gradients computed by Op # Gradients computed by Op
gradients = outputs[1] gradients = outputs[1]
# Gradients of original function, to compose chain rule # Gradients of original function, to compose chain rule
...@@ -171,8 +169,7 @@ def gpu_ctc(activations, labels, input_lengths): ...@@ -171,8 +169,7 @@ def gpu_ctc(activations, labels, input_lengths):
1-D array 1-D array
Cost of each example in the minibatch. Cost of each example in the minibatch.
""" """
return GpuConnectionistTemporalClassification()(activations, labels, return GpuConnectionistTemporalClassification()(activations, labels, input_lengths)
input_lengths)
# Disable gradient computation if not needed # Disable gradient computation if not needed
...@@ -183,5 +180,5 @@ def local_gpu_ctc_no_grad(node): ...@@ -183,5 +180,5 @@ def local_gpu_ctc_no_grad(node):
if isinstance(node.op, GpuConnectionistTemporalClassification): if isinstance(node.op, GpuConnectionistTemporalClassification):
if len(node.outputs) > 1: if len(node.outputs) > 1:
if len(node.outputs[1].clients) == 0: # gradient is not used if len(node.outputs[1].clients) == 0: # gradient is not used
node.op = GpuConnectionistTemporalClassification(compute_grad=False) node.op.compute_grad = False
node.outputs = node.outputs[:1] # costs only return [GpuConnectionistTemporalClassification(compute_grad=False)(*node.inputs)]
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论