提交 ff9c4bdc authored 作者: Pascal Lamblin's avatar Pascal Lamblin

correct handling of int/complex dtypes in IncSubtensor.grad

上级 ef079e7a
...@@ -1471,6 +1471,16 @@ class IncSubtensor(Op): ...@@ -1471,6 +1471,16 @@ class IncSubtensor(Op):
x, y = inputs[:2] x, y = inputs[:2]
idx_list = inputs[2:] idx_list = inputs[2:]
if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=theano.config.floatX)
if y.dtype in theano.tensor.discrete_dtypes:
gy = y.zeros_like(dtype=theano.config.floatX)
else:
gy = y.zeros_like()
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
if self.set_instead_of_inc: if self.set_instead_of_inc:
gx = set_subtensor( gx = set_subtensor(
Subtensor(idx_list=self.idx_list)(g_output, *idx_list), Subtensor(idx_list=self.idx_list)(g_output, *idx_list),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论