提交 8054c46b authored 作者: Ian Goodfellow's avatar Ian Goodfellow 提交者: Frederic

handle case where all outputs are int

上级 c6de7352
...@@ -15,6 +15,7 @@ from theano.printing import min_informative_str, pprint ...@@ -15,6 +15,7 @@ from theano.printing import min_informative_str, pprint
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
from theano.tensor.utils import hash_from_dict from theano.tensor.utils import hash_from_dict
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.gof.null_type import NullType
config = theano.config config = theano.config
...@@ -638,9 +639,42 @@ class Elemwise(Op): ...@@ -638,9 +639,42 @@ class Elemwise(Op):
def grad(self, inputs, ograds): def grad(self, inputs, ograds):
outs = self(*inputs)
if not isinstance(outs, (list,tuple)):
outs = [ outs ]
#compute grad with respect to broadcasted input #compute grad with respect to broadcasted input
rval = self._bgrad(inputs, ograds) rval = self._bgrad(inputs, ograds)
# TODO: make sure that zeros are clearly identifiable
# to the gradient.grad method when the outputs have
# some integer and some floating point outputs
if False in [str(out.type.dtype).find('int') == -1
for out in outs]:
# For integer output, return value may
# only be zero or undefined
# We don't bother with trying to check
# that the scalar ops correctly
# returned something that evaluates to 0,
# we just make the return
# value obviously zero so that gradient.grad
# can tell this op did
# the right thing.
new_rval = []
for elem, ipt in zip(rval, inputs):
if isinstance(elem.type, (NullType, DisconnectedType)):
new_rval.append(elem)
else:
elem = ipt.zeros_like()
if str(elem.type.dtype).find('int') != -1:
elem = elem.astype(theano.config.floatX)
assert str(elem.type.dtype).find('int') == -1
new_rval.append(elem)
return new_rval
#sum out the broadcasted dimensions #sum out the broadcasted dimensions
for i, ipt in enumerate(inputs): for i, ipt in enumerate(inputs):
if rval[i] is None: if rval[i] is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论