提交 46a16877 authored 作者: Caglar's avatar Caglar

Added the batched_dot22 test and fix a few bugs in batched_dot22.

上级 33b8236f
......@@ -4337,9 +4337,10 @@ def batched_dot22(x, y):
>>> second = T.tensor3('second')
>>> result = batched_dot22(first, second)
"""
result, updates = theano.scan(fn=lambda x_mat, y_mat: T.dot(x_mat.T, y_mat),
result, updates = theano.scan(fn=lambda x_mat, y_mat:
theano.tensor.dot(x_mat.T, y_mat),
outputs_info=None,
sequences=[x_mat, y_mat],
sequences=[x, y],
non_sequences=None)
return result
......
......@@ -1894,6 +1894,17 @@ def _approx_eq(a, b, eps=1.0e-4):
return True
_approx_eq.debug = 0
def test_batched_dot22():
first = theano.tensor.tensor3("first")
second = theano.tensor.tensor3("second")
output = theano.tensor.basic.batched_dot22(first, second)
first_val = numpy.random.rand(10, 20, 10)
second_val = numpy.random.rand(10, 20, 5)
result_fn = theano.function([first, second], output)
result = result_fn(first_val, second_val)
assert result.shape[0] == first_val.shape[0]
assert result.shape[1] == first_val.shape[2]
assert result.shape[2] == second_val.shape[2]
def test_tensor_values_eq_approx():
#test, inf, -inf and nan equal themself
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论