提交 630accbc authored 作者: Caglar's avatar Caglar

Rename the batched_dot22 function to be more generic and support other tensors as well.

上级 46a16877
......@@ -4325,20 +4325,20 @@ def set_subtensor(x, y, inplace=False,
return inc_subtensor(x, y, inplace, set_instead_of_inc=True,
tolerate_inplace_aliasing=tolerate_inplace_aliasing)
def batched_dot22(x, y):
def batched_dot(x, y):
"""
:param x: A 3D Tensor with sizes (dim1, dim2, dim3)
:param y: A 3D Tensor with sizes (dim1, dim2, dim4)
This function computes the dot product between the two 3D tensors, by iterating
:param x: A Tensor with sizes e.g.: for 3D (dim1, dim3, dim2)
:param y: A Tensor with sizes e.g.: for 3D (dim1, dim2, dim4)
This function computes the dot product between the two tensors, by iterating
over the first dimension using scan.
Returns a 3D tensor of size (dim1, dim3, dim4)
Returns a tensor of size e.g. if it is 3D: (dim1, dim3, dim4)
Example:
>>> first = T.tensor3('first')
>>> second = T.tensor3('second')
>>> result = batched_dot22(first, second)
>>> result = batched_dot(first, second)
"""
result, updates = theano.scan(fn=lambda x_mat, y_mat:
theano.tensor.dot(x_mat.T, y_mat),
theano.tensor.dot(x_mat, y_mat),
outputs_info=None,
sequences=[x, y],
non_sequences=None)
......
......@@ -1898,12 +1898,12 @@ def test_batched_dot22():
first = theano.tensor.tensor3("first")
second = theano.tensor.tensor3("second")
output = theano.tensor.basic.batched_dot22(first, second)
first_val = numpy.random.rand(10, 20, 10)
first_val = numpy.random.rand(10, 10, 20)
second_val = numpy.random.rand(10, 20, 5)
result_fn = theano.function([first, second], output)
result = result_fn(first_val, second_val)
assert result.shape[0] == first_val.shape[0]
assert result.shape[1] == first_val.shape[2]
assert result.shape[1] == first_val.shape[1]
assert result.shape[2] == second_val.shape[2]
def test_tensor_values_eq_approx():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论