Correct flake8 errors in ctc

上级 f995eeec
import numpy as np import numpy as np
import theano
import theano.tensor as T import theano.tensor as T
from theano import config from theano import config
from theano import gof from theano import gof
...@@ -12,6 +11,7 @@ import os ...@@ -12,6 +11,7 @@ import os
ctc_enabled = config.ctc.enabled ctc_enabled = config.ctc.enabled
class ConnectionistTemporalClassification(gof.COp): class ConnectionistTemporalClassification(gof.COp):
__props__ = ('compute_grad',) __props__ = ('compute_grad',)
...@@ -31,8 +31,9 @@ class ConnectionistTemporalClassification(gof.COp): ...@@ -31,8 +31,9 @@ class ConnectionistTemporalClassification(gof.COp):
self.gradients = T.ftensor3(name="ctc_grad") self.gradients = T.ftensor3(name="ctc_grad")
if config.ctc.root == "": if config.ctc.root == "":
raise ValueError("ctc.root variable is not set, please set it " + raise ValueError('ctc.root variable is not set, please set it '
"to the root directory of the CTC library in your system.") 'to the root directory of the CTC library in '
'your system.')
def c_compile_args(self): def c_compile_args(self):
return self.openmp_op.c_compile_args() return self.openmp_op.c_compile_args()
...@@ -67,7 +68,7 @@ class ConnectionistTemporalClassification(gof.COp): ...@@ -67,7 +68,7 @@ class ConnectionistTemporalClassification(gof.COp):
t_activations = T.as_tensor_variable(activations) t_activations = T.as_tensor_variable(activations)
t_labels = T.as_tensor_variable(labels) t_labels = T.as_tensor_variable(labels)
t_input_lengths = T.cast(activations.shape[0], dtype="int32") * \ t_input_lengths = T.cast(activations.shape[0], dtype="int32") * \
T.ones_like(activations[0,:,0], dtype=np.int32) T.ones_like(activations[0, :, 0], dtype=np.int32)
# Return only the cost. Gradient will be returned by grad() # Return only the cost. Gradient will be returned by grad()
self.default_output = 0 self.default_output = 0
...@@ -83,14 +84,16 @@ class ConnectionistTemporalClassification(gof.COp): ...@@ -83,14 +84,16 @@ class ConnectionistTemporalClassification(gof.COp):
# self.gradients.shape = [seqLen, batchSize, outputSize] # self.gradients.shape = [seqLen, batchSize, outputSize]
# output_grads[0].shape = [batchSize] (one cost per sequence) # output_grads[0].shape = [batchSize] (one cost per sequence)
# So, reshape output_grads to [1, batchSize, 1] for broadcasting # So, reshape output_grads to [1, batchSize, 1] for broadcasting
output_grad = output_grads[0].reshape( (1, -1, 1) ) output_grad = output_grads[0].reshape((1, -1, 1))
return [output_grad * self.gradients, return [output_grad * self.gradients,
grad_undefined(self, 1, inputs[1]), grad_undefined(self, 1, inputs[1]),
grad_undefined(self, 2, inputs[2])] grad_undefined(self, 2, inputs[2])]
def ctc(activations, labels, input_lengths=None): def ctc(activations, labels, input_lengths=None):
return ConnectionistTemporalClassification()(activations, labels, input_lengths) return ConnectionistTemporalClassification()(activations, labels, input_lengths)
# Disable gradient computation if not needed # Disable gradient computation if not needed
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论