提交 2566bedd authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5175 from abergeron/fix_elem_grad_bool

Fix grad of elemwise with boolean inputs.
...@@ -9,6 +9,7 @@ from six.moves import xrange ...@@ -9,6 +9,7 @@ from six.moves import xrange
import theano import theano
from theano import gof from theano import gof
from theano.compat import izip from theano.compat import izip
from theano.configparser import change_flags
from theano.gof import Apply, Op, OpenMPOp from theano.gof import Apply, Op, OpenMPOp
from theano import scalar from theano import scalar
from theano.scalar import get_scalar_type from theano.scalar import get_scalar_type
...@@ -667,8 +668,8 @@ second dimension ...@@ -667,8 +668,8 @@ second dimension
# TODO: make sure that zeros are clearly identifiable # TODO: make sure that zeros are clearly identifiable
# to the gradient.grad method when the outputs have # to the gradient.grad method when the outputs have
# some integer and some floating point outputs # some integer and some floating point outputs
if False in [str(out.type.dtype).find('int') == -1 if any(out.type.dtype not in theano.tensor.continuous_dtypes
for out in outs]: for out in outs):
# For integer output, return value may # For integer output, return value may
# only be zero or undefined # only be zero or undefined
# We don't bother with trying to check # We don't bother with trying to check
...@@ -684,7 +685,7 @@ second dimension ...@@ -684,7 +685,7 @@ second dimension
new_rval.append(elem) new_rval.append(elem)
else: else:
elem = ipt.zeros_like() elem = ipt.zeros_like()
if str(elem.type.dtype).find('int') != -1: if str(elem.type.dtype) not in theano.tensor.continuous_dtypes:
elem = elem.astype(theano.config.floatX) elem = elem.astype(theano.config.floatX)
assert str(elem.type.dtype).find('int') == -1 assert str(elem.type.dtype).find('int') == -1
new_rval.append(elem) new_rval.append(elem)
...@@ -724,12 +725,7 @@ second dimension ...@@ -724,12 +725,7 @@ second dimension
def _bgrad(self, inputs, ograds): def _bgrad(self, inputs, ograds):
# returns grad, with respect to broadcasted versions of inputs # returns grad, with respect to broadcasted versions of inputs
prev_setting = theano.config.compute_test_value with change_flags(compute_test_value='off'):
try:
theano.config.compute_test_value = 'off'
def as_scalar(t): def as_scalar(t):
if isinstance(t.type, (NullType, DisconnectedType)): if isinstance(t.type, (NullType, DisconnectedType)):
return t return t
...@@ -741,10 +737,6 @@ second dimension ...@@ -741,10 +737,6 @@ second dimension
for igrad in scalar_igrads: for igrad in scalar_igrads:
assert igrad is not None, self.scalar_op assert igrad is not None, self.scalar_op
finally:
theano.config.compute_test_value = prev_setting
if not isinstance(scalar_igrads, (list, tuple)): if not isinstance(scalar_igrads, (list, tuple)):
raise TypeError('%s.grad returned %s instead of list or tuple' % raise TypeError('%s.grad returned %s instead of list or tuple' %
(str(self.scalar_op), str(type(scalar_igrads)))) (str(self.scalar_op), str(type(scalar_igrads))))
......
...@@ -1154,6 +1154,11 @@ class TestBitOpReduceGrad(unittest.TestCase): ...@@ -1154,6 +1154,11 @@ class TestBitOpReduceGrad(unittest.TestCase):
class TestElemwise(unittest_tools.InferShapeTester): class TestElemwise(unittest_tools.InferShapeTester):
def test_elemwise_grad_bool(self):
    # Regression test for PR #5175 ("Fix grad of elemwise with boolean
    # inputs"): theano.grad on an elemwise node with a bool input used
    # to fail the dtype check in L_op.  The success criterion here is
    # simply that theano.grad returns without raising; the values of
    # dx/dy are not asserted.
    # NOTE(review): per the L_op change in this same diff, inputs whose
    # dtype is not in theano.tensor.continuous_dtypes get zeros_like()
    # cast to floatX — so dx/dy are presumably floatX zeros; confirm.
    x = theano.tensor.scalar(dtype='bool')
    y = theano.tensor.bscalar()  # int8 scalar — second discrete dtype
    z = x * y
    dx, dy = theano.grad(z, [x, y])
def test_infer_shape(self): def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论