提交 ff9c4bdc authored 作者: Pascal Lamblin's avatar Pascal Lamblin

correct handling of int/complex dtypes in IncSubtensor.grad

上级 ef079e7a
......@@ -1471,36 +1471,46 @@ class IncSubtensor(Op):
x, y = inputs[:2]
idx_list = inputs[2:]
if self.set_instead_of_inc:
gx = set_subtensor(
Subtensor(idx_list=self.idx_list)(g_output, *idx_list),
theano.tensor.zeros_like(y))
if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=theano.config.floatX)
if y.dtype in theano.tensor.discrete_dtypes:
gy = y.zeros_like(dtype=theano.config.floatX)
else:
gy = y.zeros_like()
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
if gy.broadcastable != y.broadcastable:
y_dim_added = gy.ndim - y.ndim
y_broad = (True,) * y_dim_added + y.broadcastable
assert sum(gy.broadcastable) < sum(y_broad)
axis_to_sum = []
for i in range(gy.ndim):
if gy.broadcastable[i] is False and y_broad[i] is True:
axis_to_sum.append(i)
elif (gy.broadcastable[i] is True and
y_broad[i] is False):
# This mean that Theano where able to infer that
# gy.shape[i] is 1, so y.shape[i] is 1, but we
# didn't know it. It is fine.
pass
else:
assert gy.broadcastable[i] == y_broad[i]
gy = gy.sum(axis=axis_to_sum, keepdims=True)
if gy.ndim != y.ndim:
assert gy.ndim > y.ndim
for i in range(y_dim_added):
assert gy.broadcastable[i]
gy = gy.dimshuffle(*range(y_dim_added, gy.ndim))
assert gy.broadcastable == y.broadcastable
if self.set_instead_of_inc:
gx = set_subtensor(
Subtensor(idx_list=self.idx_list)(g_output, *idx_list),
theano.tensor.zeros_like(y))
else:
gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
if gy.broadcastable != y.broadcastable:
y_dim_added = gy.ndim - y.ndim
y_broad = (True,) * y_dim_added + y.broadcastable
assert sum(gy.broadcastable) < sum(y_broad)
axis_to_sum = []
for i in range(gy.ndim):
if gy.broadcastable[i] is False and y_broad[i] is True:
axis_to_sum.append(i)
elif (gy.broadcastable[i] is True and
y_broad[i] is False):
# This mean that Theano where able to infer that
# gy.shape[i] is 1, so y.shape[i] is 1, but we
# didn't know it. It is fine.
pass
else:
assert gy.broadcastable[i] == y_broad[i]
gy = gy.sum(axis=axis_to_sum, keepdims=True)
if gy.ndim != y.ndim:
assert gy.ndim > y.ndim
for i in range(y_dim_added):
assert gy.broadcastable[i]
gy = gy.dimshuffle(*range(y_dim_added, gy.ndim))
assert gy.broadcastable == y.broadcastable
return [gx, gy] + [DisconnectedType()()] * len(idx_list)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论