提交 592e1501，作者：Frederic

Fix Join.grad when all inputs/outputs dtype are ints.

上级 0beeea9e
...@@ -3494,8 +3494,8 @@ class Join(Op): ...@@ -3494,8 +3494,8 @@ class Join(Op):
else: else:
# the output has integer type, so the gradient through it # the output has integer type, so the gradient through it
# is 0 # is 0
# TODO what should be there? rval = rval + [tensor.zeros_like(dtype=config.floatX)
rval = rval + [tensor.zeros_like() for tensor in tensors] for tensor in tensors]
return rval return rval
......
...@@ -3306,6 +3306,20 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3306,6 +3306,20 @@ class T_Join_and_Split(unittest.TestCase):
utt.verify_grad(lambda b: join(1, a, b), [bv], utt.verify_grad(lambda b: join(1, a, b), [bv],
eps=1.0e-4, rel_tol=1.0e-3, mode=self.mode) eps=1.0e-4, rel_tol=1.0e-3, mode=self.mode)
def test_join_matrix_ints(self):
    """Join tensors of mixed integer dtypes (int8 with int32).

    Regression test: Join.grad used to crash when every input/output
    was integer-typed; the gradient through an integer output must be
    an all-zero tensor instead.
    """
    a_val = numpy.array([[1, 2, 3], [4, 5, 6]], dtype='int8')
    b_val = numpy.array([[7], [8]], dtype='int32')
    a = self.shared(a_val)
    b = as_tensor_variable(b_val)
    s = join(1, a, b)

    # Forward pass: columns of `b` appended after the columns of `a`.
    expected = numpy.array([[1, 2, 3, 7], [4, 5, 6, 8]], dtype='float32')
    out = self.eval_outputs_and_check_join([s])
    self.assertTrue((out == expected).all())

    # Backward pass: integer dtypes carry no gradient, so both inputs
    # must receive zeros (and the call must not raise).
    self.assertTrue((grad(s.sum(), b).eval() == 0).all())
    self.assertTrue((grad(s.sum(), a).eval() == 0).all())
def test_join_matrix1_using_vertical_stack(self): def test_join_matrix1_using_vertical_stack(self):
a = self.shared(numpy.array([[1, 2, 3], [4, 5, 6]], dtype=self.floatX)) a = self.shared(numpy.array([[1, 2, 3], [4, 5, 6]], dtype=self.floatX))
b = as_tensor_variable(numpy.array([[7, 8, 9]], dtype=self.floatX)) b = as_tensor_variable(numpy.array([[7, 8, 9]], dtype=self.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论