提交 46a16877 authored 作者: Caglar's avatar Caglar

Added the batched_dot22 test and fix a few bugs in batched_dot22.

上级 33b8236f
...@@ -4337,9 +4337,10 @@ def batched_dot22(x, y): ...@@ -4337,9 +4337,10 @@ def batched_dot22(x, y):
>>> second = T.tensor3('second') >>> second = T.tensor3('second')
>>> result = batched_dot22(first, second) >>> result = batched_dot22(first, second)
""" """
result, updates = theano.scan(fn=lambda x_mat, y_mat: T.dot(x_mat.T, y_mat), result, updates = theano.scan(fn=lambda x_mat, y_mat:
theano.tensor.dot(x_mat.T, y_mat),
outputs_info=None, outputs_info=None,
sequences=[x_mat, y_mat], sequences=[x, y],
non_sequences=None) non_sequences=None)
return result return result
......
...@@ -1894,6 +1894,17 @@ def _approx_eq(a, b, eps=1.0e-4): ...@@ -1894,6 +1894,17 @@ def _approx_eq(a, b, eps=1.0e-4):
return True return True
_approx_eq.debug = 0 _approx_eq.debug = 0
def test_batched_dot22():
first = theano.tensor.tensor3("first")
second = theano.tensor.tensor3("second")
output = theano.tensor.basic.batched_dot22(first, second)
first_val = numpy.random.rand(10, 20, 10)
second_val = numpy.random.rand(10, 20, 5)
result_fn = theano.function([first, second], output)
result = result_fn(first_val, second_val)
assert result.shape[0] == first_val.shape[0]
assert result.shape[1] == first_val.shape[2]
assert result.shape[2] == second_val.shape[2]
def test_tensor_values_eq_approx(): def test_tensor_values_eq_approx():
#test, inf, -inf and nan equal themself #test, inf, -inf and nan equal themself
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论