提交 71e615c9 authored 作者: Caglar's avatar Caglar

Updated the test function for the batched dot operation.

上级 630accbc
......@@ -1894,10 +1894,10 @@ def _approx_eq(a, b, eps=1.0e-4):
return True
_approx_eq.debug = 0
def test_batched_dot22():
def test_batched_dot():
first = theano.tensor.tensor3("first")
second = theano.tensor.tensor3("second")
output = theano.tensor.basic.batched_dot22(first, second)
output = theano.tensor.basic.batched_dot(first, second)
first_val = numpy.random.rand(10, 10, 20)
second_val = numpy.random.rand(10, 20, 5)
result_fn = theano.function([first, second], output)
......@@ -1906,6 +1906,16 @@ def test_batched_dot22():
assert result.shape[1] == first_val.shape[1]
assert result.shape[2] == second_val.shape[2]
first_mat = theano.tensor.dmatrix("first")
second_mat = theano.tensor.dmatrix("second")
output = theano.tensor.basic.batched_dot(first_mat, second_mat)
first_mat_val = numpy.random.rand(10, 10)
second_mat_val = numpy.random.rand(10, 10)
result_fn = theano.function([first_mat, second_mat], output)
result = result_fn(first_mat_val, second_mat_val)
assert result.shape[0] == first_val.shape[0]
def test_tensor_values_eq_approx():
#test, inf, -inf and nan equal themself
a = numpy.asarray([-numpy.inf, -1, 0, 1, numpy.inf, numpy.nan])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论