Correct flake8 errors in ctc

Parent: f995eeec
import numpy as np
import theano
import theano.tensor as T
from theano import config
from theano import gof
@@ -12,6 +11,7 @@ import os

ctc_enabled = config.ctc.enabled


class ConnectionistTemporalClassification(gof.COp):
    __props__ = ('compute_grad',)
@@ -31,8 +31,9 @@ class ConnectionistTemporalClassification(gof.COp):
        self.gradients = T.ftensor3(name="ctc_grad")

        if config.ctc.root == "":
            raise ValueError('ctc.root variable is not set, please set it '
                             'to the root directory of the CTC library in '
                             'your system.')

    def c_compile_args(self):
        return self.openmp_op.c_compile_args()
@@ -43,7 +44,7 @@ class ConnectionistTemporalClassification(gof.COp):
        # We assume here that the compiled library (libwarpctc.so) is available
        # at the build directory of the CTC root directory.
        dirs.append(os.path.join(config.ctc.root, "build"))
        return dirs

    def c_libraries(self):
        return ["warpctc"]
@@ -67,10 +68,10 @@ class ConnectionistTemporalClassification(gof.COp):
        t_activations = T.as_tensor_variable(activations)
        t_labels = T.as_tensor_variable(labels)
        t_input_lengths = T.cast(activations.shape[0], dtype="int32") * \
            T.ones_like(activations[0, :, 0], dtype=np.int32)

        # Return only the cost. Gradient will be returned by grad()
        self.default_output = 0

        return gof.Apply(self, inputs=[t_activations, t_labels, t_input_lengths],
                         outputs=[self.costs, self.gradients])
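For reference, the t_input_lengths expression above broadcasts the scalar number of time steps over the batch dimension, producing one int32 length per sequence. A minimal numpy sketch of the same computation, with shapes assumed from the comments in grad() below:

    import numpy as np

    acts = np.zeros((5, 3, 4), dtype=np.float32)  # (seqLen, batchSize, outputSize)
    lengths = np.int32(acts.shape[0]) * np.ones_like(acts[0, :, 0], dtype=np.int32)
    # lengths == array([5, 5, 5], dtype=int32): each sequence spans all time steps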
@@ -83,20 +84,22 @@ class ConnectionistTemporalClassification(gof.COp):
        # self.gradients.shape = [seqLen, batchSize, outputSize]
        # output_grads[0].shape = [batchSize] (one cost per sequence)
        # So, reshape output_grads to [1, batchSize, 1] for broadcasting
        output_grad = output_grads[0].reshape((1, -1, 1))
        return [output_grad * self.gradients,
                grad_undefined(self, 1, inputs[1]),
                grad_undefined(self, 2, inputs[2])]
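The reshape above aligns the per-sequence cost gradients with self.gradients so that a plain multiplication broadcasts over the time and output axes. A quick numpy check of the shapes involved (illustrative values only):

    import numpy as np

    grads = np.ones((5, 3, 4), dtype=np.float32)     # (seqLen, batchSize, outputSize)
    cost_grads = np.arange(3, dtype=np.float32)      # one gradient per sequence
    scaled = cost_grads.reshape((1, -1, 1)) * grads  # broadcasts to (5, 3, 4)
    assert scaled.shape == grads.shape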
def ctc(activations, labels, input_lengths=None):
    return ConnectionistTemporalClassification()(activations, labels, input_lengths)
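A hypothetical usage sketch (variable names and shapes are illustrative, not part of this change): float32 activations of shape (seqLen, batchSize, outputSize) go in, one cost per sequence comes out, and T.grad triggers the grad() method above:

    import theano
    import theano.tensor as T

    acts = T.ftensor3("acts")     # (seqLen, batchSize, outputSize)
    labels = T.imatrix("labels")  # one row of labels per sequence
    costs = ctc(acts, labels)     # one CTC cost per sequence
    grad = T.grad(costs.mean(), wrt=acts)
    f = theano.function([acts, labels], [costs, grad])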
# Disable gradient computation if not needed
@register_canonicalize
@register_stabilize
@local_optimizer([ConnectionistTemporalClassification])
def local_ConnectionistTemporalClassification_no_grad(node):
    if isinstance(node.op, ConnectionistTemporalClassification):
        if len(node.outputs) > 1:
            if len(node.outputs[1].clients) == 0:  # gradient is not used
                node.op = ConnectionistTemporalClassification(compute_grad=False)
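Since __props__ is ('compute_grad',), two instances of the op compare equal exactly when their compute_grad flags match, which is what allows the rewrite above to substitute a gradient-free instance when the gradient output has no clients. An illustrative check:

    op_full = ConnectionistTemporalClassification()
    op_no_grad = ConnectionistTemporalClassification(compute_grad=False)
    assert op_full != op_no_grad  # props differ, so the ops differ
    assert op_no_grad == ConnectionistTemporalClassification(compute_grad=False)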