提交 630accbc authored 作者: Caglar's avatar Caglar

Rename the batched_dot22 function to be more generic and support other tensors as well.

上级 46a16877
...@@ -4325,20 +4325,20 @@ def set_subtensor(x, y, inplace=False, ...@@ -4325,20 +4325,20 @@ def set_subtensor(x, y, inplace=False,
return inc_subtensor(x, y, inplace, set_instead_of_inc=True, return inc_subtensor(x, y, inplace, set_instead_of_inc=True,
tolerate_inplace_aliasing=tolerate_inplace_aliasing) tolerate_inplace_aliasing=tolerate_inplace_aliasing)
def batched_dot22(x, y): def batched_dot(x, y):
""" """
:param x: A 3D Tensor with sizes (dim1, dim2, dim3) :param x: A Tensor with sizes e.g.: for 3D (dim1, dim3, dim2)
:param y: A 3D Tensor with sizes (dim1, dim2, dim4) :param y: A Tensor with sizes e.g.: for 3D (dim1, dim2, dim4)
This function computes the dot product between the two 3D tensors, by iterating This function computes the dot product between the two tensors, by iterating
over the first dimension using scan. over the first dimension using scan.
Returns a 3D tensor of size (dim1, dim3, dim4) Returns a tensor of size e.g. if it is 3D: (dim1, dim3, dim4)
Example: Example:
>>> first = T.tensor3('first') >>> first = T.tensor3('first')
>>> second = T.tensor3('second') >>> second = T.tensor3('second')
>>> result = batched_dot22(first, second) >>> result = batched_dot(first, second)
""" """
result, updates = theano.scan(fn=lambda x_mat, y_mat: result, updates = theano.scan(fn=lambda x_mat, y_mat:
theano.tensor.dot(x_mat.T, y_mat), theano.tensor.dot(x_mat, y_mat),
outputs_info=None, outputs_info=None,
sequences=[x, y], sequences=[x, y],
non_sequences=None) non_sequences=None)
......
...@@ -1898,12 +1898,12 @@ def test_batched_dot22(): ...@@ -1898,12 +1898,12 @@ def test_batched_dot22():
first = theano.tensor.tensor3("first") first = theano.tensor.tensor3("first")
second = theano.tensor.tensor3("second") second = theano.tensor.tensor3("second")
output = theano.tensor.basic.batched_dot22(first, second) output = theano.tensor.basic.batched_dot22(first, second)
first_val = numpy.random.rand(10, 20, 10) first_val = numpy.random.rand(10, 10, 20)
second_val = numpy.random.rand(10, 20, 5) second_val = numpy.random.rand(10, 20, 5)
result_fn = theano.function([first, second], output) result_fn = theano.function([first, second], output)
result = result_fn(first_val, second_val) result = result_fn(first_val, second_val)
assert result.shape[0] == first_val.shape[0] assert result.shape[0] == first_val.shape[0]
assert result.shape[1] == first_val.shape[2] assert result.shape[1] == first_val.shape[1]
assert result.shape[2] == second_val.shape[2] assert result.shape[2] == second_val.shape[2]
def test_tensor_values_eq_approx(): def test_tensor_values_eq_approx():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论