提交 730b3cdb authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Don't return int from subtensor grads

上级 49c8b94e
...@@ -1761,7 +1761,14 @@ class AdvancedSubtensor1(Op): ...@@ -1761,7 +1761,14 @@ class AdvancedSubtensor1(Op):
rval1 = [sparse_module_ref.construct_sparse_from_list(x, gz, rval1 = [sparse_module_ref.construct_sparse_from_list(x, gz,
ilist)] ilist)]
else: else:
rval1 = [advanced_inc_subtensor1(x.zeros_like(), gz, ilist)] if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=theano.config.floatX)
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
gx = x.zeros_like()
rval1 = [advanced_inc_subtensor1(gx, gz, ilist)]
return rval1 + [DisconnectedType()()] * (len(inputs) - 1) return rval1 + [DisconnectedType()()] * (len(inputs) - 1)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -2238,9 +2245,15 @@ class AdvancedSubtensor(BaseAdvancedSubtensor): ...@@ -2238,9 +2245,15 @@ class AdvancedSubtensor(BaseAdvancedSubtensor):
def grad(self, inputs, grads): def grad(self, inputs, grads):
gz, = grads gz, = grads
x = inputs[0] x = inputs[0]
if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=theano.config.floatX)
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
gx = x.zeros_like()
rest = inputs[1:] rest = inputs[1:]
return [advanced_inc_subtensor(theano.tensor.zeros_like(x), gz, return [advanced_inc_subtensor(gx, gz, *rest)] + \
*rest)] + \
[DisconnectedType()()] * len(rest) [DisconnectedType()()] * len(rest)
advanced_subtensor = AdvancedSubtensor() advanced_subtensor = AdvancedSubtensor()
...@@ -2258,9 +2271,15 @@ class AdvancedBooleanSubtensor(BaseAdvancedSubtensor): ...@@ -2258,9 +2271,15 @@ class AdvancedBooleanSubtensor(BaseAdvancedSubtensor):
def grad(self, inputs, grads): def grad(self, inputs, grads):
gz, = grads gz, = grads
x = inputs[0] x = inputs[0]
if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=theano.config.floatX)
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
gx = x.zeros_like()
rest = inputs[1:] rest = inputs[1:]
return [advanced_boolean_inc_subtensor(theano.tensor.zeros_like(x), gz, return [advanced_boolean_inc_subtensor(gx, gz, *rest)] + \
*rest)] + \
[DisconnectedType()()] * len(rest) [DisconnectedType()()] * len(rest)
advanced_boolean_subtensor = AdvancedBooleanSubtensor() advanced_boolean_subtensor = AdvancedBooleanSubtensor()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论