Unverified 提交 9f286e70 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6629 from abergeron/subtensor_grad

Don't return int from subtensor grads
......@@ -1761,7 +1761,14 @@ class AdvancedSubtensor1(Op):
rval1 = [sparse_module_ref.construct_sparse_from_list(x, gz,
ilist)]
else:
rval1 = [advanced_inc_subtensor1(x.zeros_like(), gz, ilist)]
if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=theano.config.floatX)
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
gx = x.zeros_like()
rval1 = [advanced_inc_subtensor1(gx, gz, ilist)]
return rval1 + [DisconnectedType()()] * (len(inputs) - 1)
def R_op(self, inputs, eval_points):
......@@ -2238,9 +2245,15 @@ class AdvancedSubtensor(BaseAdvancedSubtensor):
def grad(self, inputs, grads):
gz, = grads
x = inputs[0]
if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=theano.config.floatX)
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
gx = x.zeros_like()
rest = inputs[1:]
return [advanced_inc_subtensor(theano.tensor.zeros_like(x), gz,
*rest)] + \
return [advanced_inc_subtensor(gx, gz, *rest)] + \
[DisconnectedType()()] * len(rest)
advanced_subtensor = AdvancedSubtensor()
......@@ -2258,9 +2271,15 @@ class AdvancedBooleanSubtensor(BaseAdvancedSubtensor):
def grad(self, inputs, grads):
gz, = grads
x = inputs[0]
if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=theano.config.floatX)
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
gx = x.zeros_like()
rest = inputs[1:]
return [advanced_boolean_inc_subtensor(theano.tensor.zeros_like(x), gz,
*rest)] + \
return [advanced_boolean_inc_subtensor(gx, gz, *rest)] + \
[DisconnectedType()()] * len(rest)
advanced_boolean_subtensor = AdvancedBooleanSubtensor()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论