Add initial implementation of gradients to spatial transformer

上级 73a61cbc
...@@ -2903,7 +2903,7 @@ class GpuDnnTransformer(DnnBase): ...@@ -2903,7 +2903,7 @@ class GpuDnnTransformer(DnnBase):
self.dtype = dtype self.dtype = dtype
def make_node(self, img, theta, output, desc, alpha=None, beta=None): def make_node(self, img, theta, output, desc, alpha=None, beta=None):
context_name = infer_context_name(img) context_name = infer_context_name(desc)
img = gpu_contiguous(as_gpuarray_variable(img, context_name)) img = gpu_contiguous(as_gpuarray_variable(img, context_name))
if img.type.ndim != 4: if img.type.ndim != 4:
...@@ -2938,19 +2938,16 @@ class GpuDnnTransformer(DnnBase): ...@@ -2938,19 +2938,16 @@ class GpuDnnTransformer(DnnBase):
def L_op(self, inputs, outputs, grads):
    """Symbolic gradients of the spatial transformer sampler.

    Inputs are ``(img, theta, output, desc, alpha, beta)``; outputs are
    ``(sampled, grid)``.  Only the gradient of the first output (the
    sampled image) is propagated; the grid gradient comes back through
    ``GpuDnnTransformerGradI``.
    """
    img, theta, output, desc, alpha, beta = inputs
    _, grid = outputs
    dy = grads[0]
    # GradI produces both the gradient w.r.t. the image and the gradient
    # w.r.t. the sampling grid; the latter feeds GradT to get dtheta.
    dimg, dgrid = GpuDnnTransformerGradI(self.dtype)(img, theta, grid, dy,
                                                     desc, alpha, beta)
    dtheta = GpuDnnTransformerGradT(self.dtype)(dgrid, desc)
    # grad_not_implemented signature is (op, x_pos, x): position first.
    dalpha = theano.gradient.grad_not_implemented(self, 4, alpha)
    dbeta = theano.gradient.grad_not_implemented(self, 5, beta)
    # NOTE(review): returning raw `dy` as the gradient w.r.t. the
    # preallocated `output` buffer ignores the cuDNN alpha/beta blending
    # (y = alpha*result + beta*y_prior, so d/d(output) would be beta*dy);
    # the previous revision used grad_undefined here.  TODO confirm intent.
    # `desc` is a non-differentiable descriptor, hence DisconnectedType.
    return [dimg, dtheta, dy, DisconnectedType()(), dalpha, dbeta]
def connection_pattern(self, node): def connection_pattern(self, node):
# not connected to desc # not connected to desc
...@@ -3011,6 +3008,22 @@ class GpuDnnTransformerGradI(DnnBase): ...@@ -3011,6 +3008,22 @@ class GpuDnnTransformerGradI(DnnBase):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def L_op(self, inputs, outputs, grads):
    """Second-order gradients for the GradI (image-gradient) op.

    Inputs are ``(img, theta, grid, grid_dims, dy, desc, alpha, beta)``;
    outputs are ``(dimg, dgrid)``.  ``desc`` is non-differentiable.
    """
    img, theta, grid, grid_dims, dy, desc, alpha, beta = inputs
    dimg_out, dgrid = outputs
    grad_cost = grads[0]
    # NOTE(review): this scales the op's own dimg output by the incoming
    # cost gradient, and returns the op's dgrid output as the gradient
    # w.r.t. the `grid` input — confirm these are the intended
    # second-order expressions before relying on them.
    dimg = dimg_out * grad_cost
    dtheta = GpuDnnTransformerGradT(self.dtype)(dgrid, desc)
    # Fix: grad_not_implemented takes (op, x_pos, x) — the original code
    # passed the variable and its position swapped.
    dgrid_dims = grad_not_implemented(self, 3, grid_dims)
    d_dy = grad_not_implemented(self, 4, dy)
    dalpha = grad_not_implemented(self, 5, alpha)
    dbeta = grad_not_implemented(self, 6, beta)
    return [dimg, dtheta, dgrid, dgrid_dims, d_dy,
            DisconnectedType()(), dalpha, dbeta]
def connection_pattern(self, node):
    """Connectivity of the 8 inputs to the 2 outputs (dimg, dgrid).

    Every input influences both outputs except the cuDNN descriptor
    ``desc`` (input index 5), which is not differentiable.
    """
    return [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [0, 0], [1, 1], [1, 1]]
...@@ -3039,6 +3052,12 @@ class GpuDnnTransformerGradT(DnnBase): ...@@ -3039,6 +3052,12 @@ class GpuDnnTransformerGradT(DnnBase):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def L_op(self, inputs, outputs, grads):
    """Gradient of the GradT (theta-gradient) op.

    Inputs are ``(dgrid, desc)``; the single output is ``dtheta``.
    ``desc`` is a non-differentiable descriptor.
    """
    dgrid, desc = inputs
    # Fix: `outputs` is a list; the original multiplied the list itself
    # by the symbolic gradient.  Unpack the single dtheta output first.
    dtheta_out, = outputs
    grad_cost = grads[0]
    dtheta = dtheta_out * grad_cost
    return [dtheta, DisconnectedType()()]
def connection_pattern(self, node):
    """Connectivity of the 2 inputs to the single dtheta output.

    ``dgrid`` is connected; the cuDNN descriptor ``desc`` is not.
    """
    return [[1], [0]]
......
...@@ -2508,3 +2508,23 @@ def test_dnn_spatialtf_grad(): ...@@ -2508,3 +2508,23 @@ def test_dnn_spatialtf_grad():
assert any([isinstance(node.op, dnn.GpuDnnTransformerGradT) assert any([isinstance(node.op, dnn.GpuDnnTransformerGradT)
for node in grad_fn.maker.fgraph.toposort()]) for node in grad_fn.maker.fgraph.toposort()])
# Verify grad wrt input
def functor_wrt_i(inp):
    """Gradient of mean(transformed image) w.r.t. the image input.

    Renamed the parameter from ``input`` to avoid shadowing the builtin.
    Uses the outer-scope ``theta``, ``out_shp`` and ``test_ctx_name``.
    """
    out = GpuAllocEmpty(theano.config.floatX, context_name=test_ctx_name)(*out_shp)
    desc = dnn.GpuDnnTransformerDescriptor(theano.config.floatX)(out_shp)
    transformed_input = dnn.GpuDnnTransformer(theano.config.floatX)(inp, theta, out, desc)
    return T.grad(T.mean(transformed_input), inp)

# Verify grad wrt theta
def functor_wrt_t(theta_var):
    """Gradient of mean(transformed image) w.r.t. the theta parameters.

    Renamed the parameter so it no longer shadows the outer ``theta``.
    Uses the outer-scope ``img``, ``out_shp`` and ``test_ctx_name``.
    """
    out = GpuAllocEmpty(theano.config.floatX, context_name=test_ctx_name)(*out_shp)
    desc = dnn.GpuDnnTransformerDescriptor(theano.config.floatX)(out_shp)
    transformed_input = dnn.GpuDnnTransformer(theano.config.floatX)(img, theta_var, out, desc)
    return T.grad(T.mean(transformed_input), theta_var)

utt.verify_grad(functor_wrt_i, [img])
utt.verify_grad(functor_wrt_t, [theta])
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论