提交 ff9c4bdc authored 作者: Pascal Lamblin's avatar Pascal Lamblin

correct handling of int/complex dtypes in IncSubtensor.grad

上级 ef079e7a
...@@ -1471,36 +1471,46 @@ class IncSubtensor(Op): ...@@ -1471,36 +1471,46 @@ class IncSubtensor(Op):
x, y = inputs[:2] x, y = inputs[:2]
idx_list = inputs[2:] idx_list = inputs[2:]
if self.set_instead_of_inc: if x.dtype in theano.tensor.discrete_dtypes:
gx = set_subtensor( # The output dtype is the same as x
Subtensor(idx_list=self.idx_list)(g_output, *idx_list), gx = x.zeros_like(dtype=theano.config.floatX)
theano.tensor.zeros_like(y)) if y.dtype in theano.tensor.discrete_dtypes:
gy = y.zeros_like(dtype=theano.config.floatX)
else:
gy = y.zeros_like()
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else: else:
gx = g_output if self.set_instead_of_inc:
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list) gx = set_subtensor(
if gy.broadcastable != y.broadcastable: Subtensor(idx_list=self.idx_list)(g_output, *idx_list),
y_dim_added = gy.ndim - y.ndim theano.tensor.zeros_like(y))
y_broad = (True,) * y_dim_added + y.broadcastable else:
assert sum(gy.broadcastable) < sum(y_broad) gx = g_output
axis_to_sum = [] gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
for i in range(gy.ndim): if gy.broadcastable != y.broadcastable:
if gy.broadcastable[i] is False and y_broad[i] is True: y_dim_added = gy.ndim - y.ndim
axis_to_sum.append(i) y_broad = (True,) * y_dim_added + y.broadcastable
elif (gy.broadcastable[i] is True and assert sum(gy.broadcastable) < sum(y_broad)
y_broad[i] is False): axis_to_sum = []
# This means that Theano was able to infer for i in range(gy.ndim):
# gy.shape[i] is 1, so y.shape[i] is 1, but we if gy.broadcastable[i] is False and y_broad[i] is True:
# didn't know it. It is fine. axis_to_sum.append(i)
pass elif (gy.broadcastable[i] is True and
else: y_broad[i] is False):
assert gy.broadcastable[i] == y_broad[i] # This means that Theano was able to infer
gy = gy.sum(axis=axis_to_sum, keepdims=True) # gy.shape[i] is 1, so y.shape[i] is 1, but we
if gy.ndim != y.ndim: # didn't know it. It is fine.
assert gy.ndim > y.ndim pass
for i in range(y_dim_added): else:
assert gy.broadcastable[i] assert gy.broadcastable[i] == y_broad[i]
gy = gy.dimshuffle(*range(y_dim_added, gy.ndim)) gy = gy.sum(axis=axis_to_sum, keepdims=True)
assert gy.broadcastable == y.broadcastable if gy.ndim != y.ndim:
assert gy.ndim > y.ndim
for i in range(y_dim_added):
assert gy.broadcastable[i]
gy = gy.dimshuffle(*range(y_dim_added, gy.ndim))
assert gy.broadcastable == y.broadcastable
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论