提交 49d2b3f9 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2213 from nouiz/join_crash

Fix gh-2210, Join.grad crash with first input dtype being *int*
...@@ -3474,7 +3474,10 @@ class Join(Op): ...@@ -3474,7 +3474,10 @@ class Join(Op):
rval = [grad_undefined(self, 0, axis)] rval = [grad_undefined(self, 0, axis)]
if 'float' in tensors[0].dtype or 'complex' in tensors[0].dtype: dtypes = [as_tensor_variable(x).type.dtype for x in tensors]
out_dtype = scal.upcast(*dtypes)
if 'float' in out_dtype or 'complex' in out_dtype:
# assume that this is differentiable # assume that this is differentiable
split = Split(len(tensors)) split = Split(len(tensors))
split_gz = split(gz, axis, stack(*[shape(x)[axis] split_gz = split(gz, axis, stack(*[shape(x)[axis]
...@@ -3491,7 +3494,8 @@ class Join(Op): ...@@ -3491,7 +3494,8 @@ class Join(Op):
else: else:
# the output has integer type, so the gradient through it # the output has integer type, so the gradient through it
# is 0 # is 0
rval = rval + [tensor.zeros_like() for tensor in tensors] rval = rval + [tensor.zeros_like(dtype=config.floatX)
for tensor in tensors]
return rval return rval
......
...@@ -3290,6 +3290,36 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3290,6 +3290,36 @@ class T_Join_and_Split(unittest.TestCase):
utt.verify_grad(lambda a, b: join(1, a, b), [av, bv], utt.verify_grad(lambda a, b: join(1, a, b), [av, bv],
eps=1.0e-4, rel_tol=1.0e-3, mode=self.mode) eps=1.0e-4, rel_tol=1.0e-3, mode=self.mode)
def test_join_matrix_dtypes(self):
# Test mixed dtype. There was a bug that caused crash in the past.
av = numpy.array([[1, 2, 3], [4, 5, 6]], dtype='int8')
bv = numpy.array([[7], [8]], dtype='float32')
a = self.shared(av)
b = as_tensor_variable(bv)
s = join(1, a, b)
want = numpy.array([[1, 2, 3, 7], [4, 5, 6, 8]], dtype='float32')
out = self.eval_outputs_and_check_join([s])
self.assertTrue((out == want).all())
grad(s.sum(), b)
grad(s.sum(), a)
utt.verify_grad(lambda b: join(1, a, b), [bv],
eps=1.0e-4, rel_tol=1.0e-3, mode=self.mode)
def test_join_matrix_ints(self):
# Test mixed dtype. There was a bug that caused crash in the past.
av = numpy.array([[1, 2, 3], [4, 5, 6]], dtype='int8')
bv = numpy.array([[7], [8]], dtype='int32')
a = self.shared(av)
b = as_tensor_variable(bv)
s = join(1, a, b)
want = numpy.array([[1, 2, 3, 7], [4, 5, 6, 8]], dtype='float32')
out = self.eval_outputs_and_check_join([s])
self.assertTrue((out == want).all())
assert (grad(s.sum(), b).eval() == 0).all()
assert (grad(s.sum(), a).eval() == 0).all()
def test_join_matrix1_using_vertical_stack(self): def test_join_matrix1_using_vertical_stack(self):
a = self.shared(numpy.array([[1, 2, 3], [4, 5, 6]], dtype=self.floatX)) a = self.shared(numpy.array([[1, 2, 3], [4, 5, 6]], dtype=self.floatX))
b = as_tensor_variable(numpy.array([[7, 8, 9]], dtype=self.floatX)) b = as_tensor_variable(numpy.array([[7, 8, 9]], dtype=self.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论