提交 97d24a45 作者: Ian Goodfellow

better handling of disconnected grads in elemwise

上级 50acd143
......@@ -722,20 +722,19 @@ class Elemwise(Op):
def _bgrad(self, inputs, ograds):
# returns grad, with respect to broadcasted versions of inputs
# Gradients (especially on the final costs) don't have to be symbolic
# e.g., ograds will be [ 1. ] if your objective is c and the output
# of the current apply node is c
ograds = map(as_tensor_variable, ograds)
prev_setting = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'off'
scalar_inputs = [Scalar(dtype=t.type.dtype)() for t in inputs]
scalar_ograds = [Scalar(dtype=ograd.type.dtype)()
for ograd in ograds]
def as_scalar(t):
if isinstance(t.type, (NullType, DisconnectedType)):
return t
return Scalar(t.type.dtype)()
scalar_inputs = map(as_scalar, inputs)
scalar_ograds = map(as_scalar, ograds)
scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds)
for igrad in scalar_igrads:
assert igrad is not None
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论