提交 0beeea9e authored 作者: Frederic's avatar Frederic

Fix gh-2210

上级 4829418f
...@@ -3474,7 +3474,10 @@ class Join(Op): ...@@ -3474,7 +3474,10 @@ class Join(Op):
rval = [grad_undefined(self, 0, axis)] rval = [grad_undefined(self, 0, axis)]
if 'float' in tensors[0].dtype or 'complex' in tensors[0].dtype: dtypes = [as_tensor_variable(x).type.dtype for x in tensors]
out_dtype = scal.upcast(*dtypes)
if 'float' in out_dtype or 'complex' in out_dtype:
# assume that this is differentiable # assume that this is differentiable
split = Split(len(tensors)) split = Split(len(tensors))
split_gz = split(gz, axis, stack(*[shape(x)[axis] split_gz = split(gz, axis, stack(*[shape(x)[axis]
...@@ -3491,6 +3494,7 @@ class Join(Op): ...@@ -3491,6 +3494,7 @@ class Join(Op):
else: else:
# the output has integer type, so the gradient through it # the output has integer type, so the gradient through it
# is 0 # is 0
# TODO what should be there?
rval = rval + [tensor.zeros_like() for tensor in tensors] rval = rval + [tensor.zeros_like() for tensor in tensors]
return rval return rval
......
...@@ -3290,6 +3290,22 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3290,6 +3290,22 @@ class T_Join_and_Split(unittest.TestCase):
utt.verify_grad(lambda a, b: join(1, a, b), [av, bv], utt.verify_grad(lambda a, b: join(1, a, b), [av, bv],
eps=1.0e-4, rel_tol=1.0e-3, mode=self.mode) eps=1.0e-4, rel_tol=1.0e-3, mode=self.mode)
def test_join_matrix_dtypes(self):
    # Joining tensors of mixed dtypes (int8 + float32) used to crash
    # (see gh-2210); the joined result should be upcast to the common
    # dtype and remain differentiable w.r.t. its float inputs.
    int_vals = numpy.array([[1, 2, 3], [4, 5, 6]], dtype='int8')
    flt_vals = numpy.array([[7], [8]], dtype='float32')
    a = self.shared(int_vals)
    b = as_tensor_variable(flt_vals)
    joined = join(1, a, b)
    expected = numpy.array([[1, 2, 3, 7], [4, 5, 6, 8]], dtype='float32')
    out = self.eval_outputs_and_check_join([joined])
    self.assertTrue(numpy.all(out == expected))
    # Building the gradient graph must not raise for either input.
    grad(joined.sum(), b)
    grad(joined.sum(), a)
    # Numerically verify the gradient w.r.t. the float operand.
    utt.verify_grad(lambda flt: join(1, a, flt), [flt_vals],
                    eps=1.0e-4, rel_tol=1.0e-3, mode=self.mode)
def test_join_matrix1_using_vertical_stack(self): def test_join_matrix1_using_vertical_stack(self):
a = self.shared(numpy.array([[1, 2, 3], [4, 5, 6]], dtype=self.floatX)) a = self.shared(numpy.array([[1, 2, 3], [4, 5, 6]], dtype=self.floatX))
b = as_tensor_variable(numpy.array([[7, 8, 9]], dtype=self.floatX)) b = as_tensor_variable(numpy.array([[7, 8, 9]], dtype=self.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论