提交 beb09bb1 作者: Frédéric Bastien

Merge pull request #2331 from lamblin/fix_set_subtensor1_grad

[BUG] Fix gradient of advanced_set_subtensor1
...@@ -1526,36 +1526,46 @@ class IncSubtensor(Op): ...@@ -1526,36 +1526,46 @@ class IncSubtensor(Op):
x, y = inputs[:2] x, y = inputs[:2]
idx_list = inputs[2:] idx_list = inputs[2:]
if self.set_instead_of_inc: if x.dtype in theano.tensor.discrete_dtypes:
gx = set_subtensor( # The output dtype is the same as x
Subtensor(idx_list=self.idx_list)(g_output, *idx_list), gx = x.zeros_like(dtype=theano.config.floatX)
theano.tensor.zeros_like(y)) if y.dtype in theano.tensor.discrete_dtypes:
gy = y.zeros_like(dtype=theano.config.floatX)
else:
gy = y.zeros_like()
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else: else:
gx = g_output if self.set_instead_of_inc:
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list) gx = set_subtensor(
if gy.broadcastable != y.broadcastable: Subtensor(idx_list=self.idx_list)(g_output, *idx_list),
y_dim_added = gy.ndim - y.ndim theano.tensor.zeros_like(y))
y_broad = (True,) * y_dim_added + y.broadcastable else:
assert sum(gy.broadcastable) < sum(y_broad) gx = g_output
axis_to_sum = [] gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
for i in range(gy.ndim): if gy.broadcastable != y.broadcastable:
if gy.broadcastable[i] is False and y_broad[i] is True: y_dim_added = gy.ndim - y.ndim
axis_to_sum.append(i) y_broad = (True,) * y_dim_added + y.broadcastable
elif (gy.broadcastable[i] is True and assert sum(gy.broadcastable) < sum(y_broad)
y_broad[i] is False): axis_to_sum = []
# This mean that Theano where able to infer that for i in range(gy.ndim):
# gy.shape[i] is 1, so y.shape[i] is 1, but we if gy.broadcastable[i] is False and y_broad[i] is True:
# didn't know it. It is fine. axis_to_sum.append(i)
pass elif (gy.broadcastable[i] is True and
else: y_broad[i] is False):
assert gy.broadcastable[i] == y_broad[i] # This mean that Theano where able to infer that
gy = gy.sum(axis=axis_to_sum, keepdims=True) # gy.shape[i] is 1, so y.shape[i] is 1, but we
if gy.ndim != y.ndim: # didn't know it. It is fine.
assert gy.ndim > y.ndim pass
for i in range(y_dim_added): else:
assert gy.broadcastable[i] assert gy.broadcastable[i] == y_broad[i]
gy = gy.dimshuffle(*range(y_dim_added, gy.ndim)) gy = gy.sum(axis=axis_to_sum, keepdims=True)
assert gy.broadcastable == y.broadcastable if gy.ndim != y.ndim:
assert gy.ndim > y.ndim
for i in range(y_dim_added):
assert gy.broadcastable[i]
gy = gy.dimshuffle(*range(y_dim_added, gy.ndim))
assert gy.broadcastable == y.broadcastable
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
...@@ -1868,15 +1878,30 @@ class AdvancedIncSubtensor1(Op): ...@@ -1868,15 +1878,30 @@ class AdvancedIncSubtensor1(Op):
def grad(self, inputs, grads): def grad(self, inputs, grads):
g_output, = grads g_output, = grads
x, y = inputs[:2] x, y, idx_list = inputs
idx_list = inputs[2:] if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = g_output gx = x.zeros_like(dtype=theano.config.floatX)
gy = advanced_subtensor1(g_output, *idx_list) if y.dtype in theano.tensor.discrete_dtypes:
gy = y.zeros_like(dtype=theano.config.floatX)
else:
gy = y.zeros_like()
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
if self.set_instead_of_inc:
gx = advanced_set_subtensor1(
g_output,
y.zeros_like(),
idx_list)
else:
gx = g_output
gy = advanced_subtensor1(g_output, idx_list)
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()]
advanced_inc_subtensor1 = AdvancedIncSubtensor1() advanced_inc_subtensor1 = AdvancedIncSubtensor1()
advanced_set_subtensor1 = AdvancedIncSubtensor1(set_instead_of_inc=True)
def as_index_variable(idx): def as_index_variable(idx):
...@@ -2079,9 +2104,13 @@ class AdvancedIncSubtensor(Op): ...@@ -2079,9 +2104,13 @@ class AdvancedIncSubtensor(Op):
'later, or to the latest development version. ' 'later, or to the latest development version. '
'You may need to clear the cache (theano-cache clear) ' 'You may need to clear the cache (theano-cache clear) '
'afterwards.') 'afterwards.')
new_inputs = []
for inp in inputs:
if isinstance(inp, (list, tuple)):
inp = theano.tensor.as_tensor_variable(inp)
new_inputs.append(inp)
return gof.Apply(op, return gof.Apply(op,
(x, y) + inputs, (x, y) + tuple(new_inputs),
[theano.tensor.tensor( [theano.tensor.tensor(
dtype=x.type.dtype, dtype=x.type.dtype,
broadcastable=x.type.broadcastable)]) broadcastable=x.type.broadcastable)])
...@@ -2136,9 +2165,25 @@ class AdvancedIncSubtensor(Op): ...@@ -2136,9 +2165,25 @@ class AdvancedIncSubtensor(Op):
x, y = inpt[:2] x, y = inpt[:2]
idxs = inpt[2:] idxs = inpt[2:]
outgrad, = output_gradients outgrad, = output_gradients
d_x_wrt_C = outgrad if x.dtype in theano.tensor.discrete_dtypes:
d_y_wrt_C = AdvancedSubtensor()(outgrad, *idxs) # The output dtype is the same as x
return [d_x_wrt_C, d_y_wrt_C] + \ gx = x.zeros_like(dtype=theano.config.floatX)
if y.dtype in theano.tensor.discrete_dtypes:
gy = y.zeros_like(dtype=theano.config.floatX)
else:
gy = y.zeros_like()
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
if self.set_instead_of_inc:
gx = advanced_set_subtensor(
outgrad,
y.zeros_like(),
*idxs)
else:
gx = outgrad
gy = advanced_subtensor(outgrad, *idxs)
return [gx, gy] + \
[DisconnectedType()() for _ in idxs] [DisconnectedType()() for _ in idxs]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -2147,6 +2192,7 @@ class AdvancedIncSubtensor(Op): ...@@ -2147,6 +2192,7 @@ class AdvancedIncSubtensor(Op):
return self.make_node(eval_points[0], eval_points[1], return self.make_node(eval_points[0], eval_points[1],
*inputs[2:]).outputs *inputs[2:]).outputs
advanced_inc_subtensor = AdvancedIncSubtensor() advanced_inc_subtensor = AdvancedIncSubtensor()
advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
def take(a, indices, axis=None, mode='raise'): def take(a, indices, axis=None, mode='raise'):
......
...@@ -18,6 +18,10 @@ import theano.scalar as scal ...@@ -18,6 +18,10 @@ import theano.scalar as scal
import theano.tensor as tensor import theano.tensor as tensor
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.subtensor import (inc_subtensor, set_subtensor, from theano.tensor.subtensor import (inc_subtensor, set_subtensor,
advanced_inc_subtensor1,
advanced_set_subtensor1,
advanced_inc_subtensor,
advanced_set_subtensor,
Subtensor, IncSubtensor, Subtensor, IncSubtensor,
AdvancedSubtensor1, AdvancedSubtensor, AdvancedSubtensor1, AdvancedSubtensor,
advanced_subtensor1, inplace_increment, advanced_subtensor1, inplace_increment,
...@@ -519,6 +523,19 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -519,6 +523,19 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self.assertTrue(g_00.shape == (1, 3)) self.assertTrue(g_00.shape == (1, 3))
self.assertTrue(numpy.allclose(g_00, 2)) self.assertTrue(numpy.allclose(g_00, 2))
utt.verify_grad(lambda m: m[[1, 3]],
[numpy.random.rand(5, 5).astype(self.dtype)])
def fun(x, y):
return advanced_inc_subtensor1(x, y, [1, 3])
utt.verify_grad(fun, [numpy.random.rand(5, 5).astype(self.dtype),
numpy.random.rand(2, 5).astype(self.dtype)])
def fun(x, y):
return advanced_set_subtensor1(x, y, [1, 3])
utt.verify_grad(fun, [numpy.random.rand(5, 5).astype(self.dtype),
numpy.random.rand(2, 5).astype(self.dtype)])
def test_adv_sub1_idx_broadcast(self): def test_adv_sub1_idx_broadcast(self):
# The idx can be a broadcastable vector. # The idx can be a broadcastable vector.
ones = numpy.ones((4, 3), dtype=self.dtype) ones = numpy.ones((4, 3), dtype=self.dtype)
...@@ -1366,6 +1383,27 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -1366,6 +1383,27 @@ class TestAdvancedSubtensor(unittest.TestCase):
cmd = f2(0, 1, 2) == aa[[0, 1, 2], :, 0:2] cmd = f2(0, 1, 2) == aa[[0, 1, 2], :, 0:2]
self.assertTrue(cmd.all()) self.assertTrue(cmd.all())
def test_grad(self):
    """Check gradients of advanced (multi-vector) indexing ops.

    First confirms that indexing a shared variable with two integer
    vectors builds an ``AdvancedSubtensor`` node, then numerically
    verifies the gradient of:

    * reading ``m[[1, 3], [2, 4]]`` from a 5x5 matrix,
    * ``advanced_inc_subtensor`` at those same positions,
    * ``advanced_set_subtensor`` at those same positions.
    """
    # Graph-construction check: two lvector indices -> AdvancedSubtensor.
    base = self.shared(numpy.ones((1, 3), dtype=self.dtype) * 5,
                       broadcastable=(True, False))
    first_idx = tensor.lvector()
    second_idx = tensor.lvector()
    picked = base[first_idx, second_idx]
    self.assertTrue(isinstance(picked.owner.op, tensor.AdvancedSubtensor))

    # Gradient of the plain read.
    matrix = numpy.random.rand(5, 5).astype(self.dtype)
    utt.verify_grad(lambda m: m[[1, 3], [2, 4]], [matrix])

    # Gradient w.r.t. both the target (5x5) and the two written values,
    # first for the incrementing op and then for the overwriting op.
    for write_op in (advanced_inc_subtensor, advanced_set_subtensor):
        def fun(x, y):
            return write_op(x, y, [1, 3], [2, 4])

        x_val = numpy.random.rand(5, 5).astype(self.dtype)
        y_val = numpy.random.rand(2).astype(self.dtype)
        utt.verify_grad(fun, [x_val, y_val])
class TestInferShape(utt.InferShapeTester): class TestInferShape(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论