提交 33b8236f authored 作者: Caglar's avatar Caglar

Added the batch_dot22 function.

上级 66126fc3
......@@ -4325,6 +4325,23 @@ def set_subtensor(x, y, inplace=False,
return inc_subtensor(x, y, inplace, set_instead_of_inc=True,
tolerate_inplace_aliasing=tolerate_inplace_aliasing)
def batched_dot22(x, y):
"""
:param x: A 3D Tensor with sizes (dim1, dim2, dim3)
:param y: A 3D Tensor with sizes (dim1, dim2, dim4)
This function computes the dot product between the two 3D tensors, by iterating
over the first dimension using scan.
Returns a 3D tensor of size (dim1, dim3, dim4)
Example:
>>> first = T.tensor3('first')
>>> second = T.tensor3('second')
>>> result = batched_dot22(first, second)
"""
result, updates = theano.scan(fn=lambda x_mat, y_mat: T.dot(x_mat.T, y_mat),
outputs_info=None,
sequences=[x_mat, y_mat],
non_sequences=None)
return result
def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
tolerate_inplace_aliasing=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论