Add initial Python implementation of spatialtf gradients

上级 975a9b7b
#section support_code_struct
int
dnn_sptf_grad(PyGpuArrayObject * input,
PyGpuArrayObject * theta,
PyGpuArrayObject * grid,
PyArrayObject * grid_dims,
PyGpuArrayObject * dy,
cudnnSpatialTransformerDescriptor_t desc,
double alpha, double beta,
PyGpuArrayObject ** output_grad,
PyGpuArrayObject ** grid_grad,
cudnnHandle_t _handle)
{
PyErr_SetString(PyExc_NotImplementedError, "Gradient for spatial transformer is not yet implemented.");
return -1;
}
...@@ -2932,7 +2932,45 @@ class GpuDnnTransformer(DnnBase): ...@@ -2932,7 +2932,45 @@ class GpuDnnTransformer(DnnBase):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
pass img, theta, grid_dims, desc, alpha, beta = inputs
_, grid = outputs
dy = grads[0]
gop = GpuDnnTransformerGrad(self.dtype)(img, theta, grid, grid_dims, dy,
desc, alpha, beta)
return gop
class GpuDnnTransformerGrad(DnnBase):
__props__ = ('dtype',)
_cop_num_inputs = 8
_cop_num_outputs = 2
_f16_ok = True
def __init__(self, dtype=theano.config.floatX):
DnnBase.__init__(self, ["c_code/dnn_sptf_grad.c"], "dnn_sptf_grad")
self.dtype = dtype
def make_node(self, img, theta, grid, grid_dims, dy, desc, alpha, beta):
context_name = infer_context_name(img)
dimg = GpuArrayType(dtype=self.dtype,
broadcastable=img.type.ndim * (False,),
context_name=context_name)()
dtheta = GpuArrayType(dtype=self.dtype,
broadcastable=theta.type.ndim * (False,),
context_name=context_name)()
inputs = [img, theta, grid, grid_dims, dy, desc, alpha, beta]
outputs = [dimg, dtheta,
theano.gradient.grad_undefined(self, 2, grid_dims),
theano.gradient.grad_undefined(self, 3, desc),
theano.gradient.grad_undefined(self, 4, alpha),
theano.gradient.grad_undefined(self, 5, beta)]
return Apply(self, inputs, outputs)
def dnn_spatialtf(inp, theta, scale_width=1, scale_height=1, alpha=None, beta=None, def dnn_spatialtf(inp, theta, scale_width=1, scale_height=1, alpha=None, beta=None,
......
...@@ -2462,6 +2462,14 @@ def test_dnn_spatialtf(): ...@@ -2462,6 +2462,14 @@ def test_dnn_spatialtf():
img_out_gpu, = st_dnn_func(img, transform) img_out_gpu, = st_dnn_func(img, transform)
img_out = np.asarray(img_out_gpu) img_out = np.asarray(img_out_gpu)
t_dy = T.tensor4('dy')
img_grad = T.grad(None, wrt=[t_img, t_theta], known_grads={st_dnn: t_dy})
grad_fn = theano.function([t_img, t_theta, t_dy], img_grad)
dy = -1 + 2 * np.random.randn(*img.shape).astype(theano.config.floatX)
spatialtf_grad = grad_fn(img, transform, dy)
# Check if function graph contains the spatial transformer Ops # Check if function graph contains the spatial transformer Ops
topo = st_dnn_func.maker.fgraph.toposort() topo = st_dnn_func.maker.fgraph.toposort()
assert len([n for n in topo if isinstance(n.op, dnn.GpuDnnTransformer)]) == 1 assert len([n for n in topo if isinstance(n.op, dnn.GpuDnnTransformer)]) == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论