Add local optimization to disable gradient computation in CTC GPU wrapper

上级 d9c5f4fe
......@@ -5,12 +5,17 @@ import theano
from theano import Op
from theano import config
import theano.tensor as T
from theano.tensor.extra_ops import cpu_contiguous
from .basic_ops import (gpu_contiguous, as_gpuarray_variable,
infer_context_name, CGpuKernelBase)
import theano.tensor.nnet.ctc
from .type import GpuArrayType
from .opt import register_opt, op_lifter, register_opt2
from theano.gradient import grad_undefined
from theano import gof
from theano.gof import local_optimizer
from theano.tensor.opt import register_canonicalize
from theano.tensor.opt import register_stabilize
import os
import pygpu
......@@ -36,14 +41,14 @@ class GpuConnectionistTemporalClassification(CGpuKernelBase, Op):
Op.__init__(self)
CGpuKernelBase.__init__(self, self.func_file, self.func_name)
self.costs_type = GpuArrayType(dtype='float32',
broadcastable=(False,),
context_name=self.context_name)
self.costs = GpuArrayType(dtype='float32',
broadcastable=(False,),
context_name=self.context_name)
if self.compute_grad:
self.grads_type = GpuArrayType(dtype='float32',
broadcastable=(False, False, False,),
context_name=self.context_name)
self.gradients = GpuArrayType(dtype='float32',
broadcastable=(False, False, False,),
context_name=self.context_name)
if config.ctc.root == "":
raise ValueError('ctc.root variable is not set, please set it '
......@@ -67,15 +72,10 @@ class GpuConnectionistTemporalClassification(CGpuKernelBase, Op):
# We assume here that the header is available at the include directory
# of the CTC root directory.
dirs.append(os.path.join(config.ctc.root, "include"))
dirs = dirs + list(pygpu.get_include())
dirs = dirs + list(super(CGpuKernelBase, self).c_header_dirs())
return dirs
def c_headers(self):
headers = ['ctc.h']
headers = headers + super(CGpuKernelBase, self).c_headers()
headers = headers + ['<numpy_compat.h>', '<gpuarray_helper.h>']
return headers
return ['ctc.h']
def make_node(self, activations, labels, input_lengths):
if not ctc_enabled:
......@@ -91,9 +91,9 @@ class GpuConnectionistTemporalClassification(CGpuKernelBase, Op):
# Ensure activations array is C-contiguous
t_activations = gpu_contiguous(t_activations)
t_labels = as_gpuarray_variable(labels, context_name=self.context_name)
t_input_lengths = as_gpuarray_variable(input_lengths,
context_name=self.context_name)
# Labels and input lengths are always on the CPU
t_labels = T.as_tensor_variable(labels)
t_input_lengths = T.as_tensor_variable(input_lengths)
if t_activations.type.dtype != 'float32':
raise TypeError('Activations must use the float32 type!')
......@@ -107,18 +107,30 @@ class GpuConnectionistTemporalClassification(CGpuKernelBase, Op):
# Return only the cost. Gradient will be returned by grad()
self.default_output = 0
out_params = [self.costs_type()]
if self.grads_type is not None:
out_params.append(self.grads_type())
out_params = [as_gpuarray_variable(self.costs(), context_name=self.context_name)]
if self.gradients is not None:
out_params.append(as_gpuarray_variable(self.gradients(),
context_name=self.context_name))
return theano.Apply(self, inputs=[t_activations, t_labels, t_input_lengths],
outputs=out_params)
def grad(self, inputs, output_grads):
return [self.grads_type(),
def grad(self, inputs, grads):
return [as_gpuarray_variable(self.gradients(), context_name=self.context_name),
grad_undefined(self, 1, inputs[1]),
grad_undefined(self, 2, inputs[2])]
def ctc(activations, labels, input_lengths):
return GpuConnectionistTemporalClassification()(activations, labels,
input_lengths)
# Disable gradient computation if not needed
@register_canonicalize
@register_stabilize
@local_optimizer([GpuConnectionistTemporalClassification])
def local_GpuConnectionistTemporalClassification_no_grad(node):
if isinstance(node.op, GpuConnectionistTemporalClassification):
if len(node.outputs) > 1:
if len(node.outputs[1].clients) == 0: # gradient is not used
node.op = GpuConnectionistTemporalClassification(compute_grad=False)
node.outputs = node.outputs[:1] # costs only
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论