提交 6129495a authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge pull request #1031 from lamblin/fix_batched_dot_test

Fix test for when floatX=float32
......@@ -1897,8 +1897,8 @@ def test_batched_dot():
first = theano.tensor.tensor3("first")
second = theano.tensor.tensor3("second")
output = theano.tensor.basic.batched_dot(first, second)
first_val = numpy.random.rand(10, 10, 20)
second_val = numpy.random.rand(10, 20, 5)
first_val = numpy.random.rand(10, 10, 20).astype(config.floatX)
second_val = numpy.random.rand(10, 20, 5).astype(config.floatX)
result_fn = theano.function([first, second], output)
result = result_fn(first_val, second_val)
assert result.shape[0] == first_val.shape[0]
......@@ -1908,8 +1908,8 @@ def test_batched_dot():
first_mat = theano.tensor.dmatrix("first")
second_mat = theano.tensor.dmatrix("second")
output = theano.tensor.basic.batched_dot(first_mat, second_mat)
first_mat_val = numpy.random.rand(10, 10)
second_mat_val = numpy.random.rand(10, 10)
first_mat_val = numpy.random.rand(10, 10).astype(config.floatX)
second_mat_val = numpy.random.rand(10, 10).astype(config.floatX)
result_fn = theano.function([first_mat, second_mat], output)
result = result_fn(first_mat_val, second_mat_val)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论