提交 4bc20eb1 authored 作者: Frederic's avatar Frederic

fix tests related to verify_grad

上级 8b437c23
......@@ -3481,7 +3481,7 @@ class T_Join_and_Split(unittest.TestCase):
def test_join_matrixV(self):
"""variable join axis"""
v = numpy.array([[1., 2., 3.], [4., 5., 6.]], dtype=self.floatX)
v = numpy.array([[.1, .2, .3], [.4, .5, .6]], dtype=self.floatX)
a = self.shared(v.copy())
b = as_tensor_variable(v.copy())
ax = lscalar()
......@@ -3491,13 +3491,15 @@ class T_Join_and_Split(unittest.TestCase):
topo = f.maker.fgraph.toposort()
assert [True for node in topo if isinstance(node.op, self.join_op)]
want = numpy.array([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
want = numpy.array([[.1, .2, .3], [.4, .5, .6],
[.1, .2, .3], [.4, .5, .6]])
got = f(0)
self.assertTrue((got == want).all(), (got, want))
assert numpy.allclose(got, want)
want = numpy.array([[1, 2, 3, 1, 2, 3], [4, 5, 6, 4, 5, 6]])
want = numpy.array([[.1, .2, .3, .1, .2, .3],
[.4, .5, .6, .4, .5, .6]])
got = f(1)
self.assertTrue((got == want).all(), (got, want))
assert numpy.allclose(got, want)
utt.verify_grad(lambda a, b: join(0, a, b), [v, 2 * v], mode=self.mode)
utt.verify_grad(lambda a, b: join(1, a, b), [v, 2 * v], mode=self.mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论