Add initial implementation of gradients to spatial transformer

上级 73a61cbc
......@@ -2903,7 +2903,7 @@ class GpuDnnTransformer(DnnBase):
self.dtype = dtype
def make_node(self, img, theta, output, desc, alpha=None, beta=None):
context_name = infer_context_name(img)
context_name = infer_context_name(desc)
img = gpu_contiguous(as_gpuarray_variable(img, context_name))
if img.type.ndim != 4:
......@@ -2938,19 +2938,16 @@ class GpuDnnTransformer(DnnBase):
def L_op(self, inputs, outputs, grads):
    """Gradient of the cuDNN spatial transformer.

    Returns one gradient per input, in order:
    (img, theta, output, desc, alpha, beta).

    The descriptor is a configuration object and is disconnected from
    the cost; gradients wrt alpha and beta are not implemented.
    """
    img, theta, output, desc, alpha, beta = inputs
    # outputs = (transformed image, sampling grid); only the sampling
    # grid is needed to back-propagate through the sampler.
    _, grid = outputs
    dy = grads[0]
    # Sampler backward gives the gradient wrt the input image and wrt
    # the grid; the grid gradient is then mapped back onto theta by the
    # grid-generator backward op.
    dimg, dgrid = GpuDnnTransformerGradI(self.dtype)(img, theta, grid, dy,
                                                     desc, alpha, beta)
    dtheta = GpuDnnTransformerGradT(self.dtype)(dgrid, desc)
    dalpha = theano.gradient.grad_not_implemented(self, 4, alpha)
    dbeta = theano.gradient.grad_not_implemented(self, 5, beta)
    # NOTE(review): `dy` is returned as the gradient wrt the
    # preallocated `output` buffer — confirm this is intended rather
    # than grad_undefined, which an earlier revision used here.
    return [dimg, dtheta, dy, DisconnectedType()(), dalpha, dbeta]
def connection_pattern(self, node):
# not connected to desc
......@@ -3011,6 +3008,22 @@ class GpuDnnTransformerGradI(DnnBase):
return Apply(self, inputs, outputs)
def L_op(self, inputs, outputs, grads):
    """Gradient of the sampler-backward op itself (grad-of-grad).

    Returns one gradient per input, in order:
    (img, theta, grid, grid_dims, dy, desc, alpha, beta).
    """
    img, theta, grid, grid_dims, dy, desc, alpha, beta = inputs
    dimg_out, dgrid = outputs
    grad_cost = grads[0]
    dimg = dimg_out * grad_cost
    dtheta = GpuDnnTransformerGradT(self.dtype)(dgrid, desc)
    # grad_not_implemented's signature is (op, x_pos, x): the input's
    # position comes BEFORE the variable, and must match its index in
    # `inputs` above (grid_dims=3, dy=4, alpha=6, beta=7).  The
    # previous code both swapped the arguments and used wrong indices
    # for alpha and beta.
    dgrid_dims = grad_not_implemented(self, 3, grid_dims)
    d_dy = grad_not_implemented(self, 4, dy)
    dalpha = grad_not_implemented(self, 6, alpha)
    dbeta = grad_not_implemented(self, 7, beta)
    # NOTE(review): `dgrid` (an output of this op) is returned as the
    # gradient wrt the `grid` input — confirm this is the intended
    # chain-rule term.
    return [dimg, dtheta, dgrid, dgrid_dims, d_dy,
            DisconnectedType()(), dalpha, dbeta]
def connection_pattern(self, node):
    """Connectivity of the 8 inputs to the 2 outputs.

    Input 5 is the cuDNN descriptor — a configuration object that never
    carries a gradient; every other input is connected to both outputs.
    """
    return [[0, 0] if idx == 5 else [1, 1] for idx in range(8)]
......@@ -3039,6 +3052,12 @@ class GpuDnnTransformerGradT(DnnBase):
return Apply(self, inputs, outputs)
def L_op(self, inputs, outputs, grads):
    """Gradient of the grid-generator backward op.

    Returns gradients for (dgrid, desc); the descriptor is a
    configuration object and is disconnected from the cost.
    """
    dgrid, desc = inputs
    # Unpack the single output before scaling: the previous code
    # multiplied the `outputs` *list* by the cost gradient, which is a
    # bug (list repetition / type error, never a tensor product).
    dtheta_out, = outputs
    grad_cost = grads[0]
    dtheta = dtheta_out * grad_cost
    return [dtheta, DisconnectedType()()]
def connection_pattern(self, node):
    """Connectivity of (dgrid, desc) to the single output."""
    pattern = []
    pattern.append([1])  # dgrid drives dtheta
    pattern.append([0])  # desc is configuration only; no gradient flows
    return pattern
......
......@@ -2508,3 +2508,23 @@ def test_dnn_spatialtf_grad():
assert any([isinstance(node.op, dnn.GpuDnnTransformerGradT)
for node in grad_fn.maker.fgraph.toposort()])
# Verify grad wrt input
def functor_wrt_i(input):
    """Build the symbolic gradient of mean(transformed image) wrt the
    sampler's input image, for use with utt.verify_grad.

    Relies on `theta`, `out_shp` and `test_ctx_name` from the enclosing
    test scope.  NOTE(review): the parameter shadows the builtin
    `input` — consider renaming.
    """
    # Preallocated output buffer the transformer writes into.
    out = GpuAllocEmpty(theano.config.floatX, context_name=test_ctx_name)(*out_shp)
    desc = dnn.GpuDnnTransformerDescriptor(theano.config.floatX)(out_shp)
    transformed_input = dnn.GpuDnnTransformer(theano.config.floatX)(input, theta, out, desc)
    grad = T.grad(T.mean(transformed_input), input)
    return grad
# Verify grad wrt theta
def functor_wrt_t(theta):
    """Build the symbolic gradient of mean(transformed image) wrt the
    affine transformation parameters theta, for use with
    utt.verify_grad.

    Relies on `img`, `out_shp` and `test_ctx_name` from the enclosing
    test scope.
    """
    # Preallocated output buffer the transformer writes into.
    out = GpuAllocEmpty(theano.config.floatX, context_name=test_ctx_name)(*out_shp)
    desc = dnn.GpuDnnTransformerDescriptor(theano.config.floatX)(out_shp)
    transformed_input = dnn.GpuDnnTransformer(theano.config.floatX)(img, theta, out, desc)
    grad = T.grad(T.mean(transformed_input), theta)
    return grad
utt.verify_grad(functor_wrt_i, [img])
utt.verify_grad(functor_wrt_t, [theta])
\ No newline at end of file
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论