提交 980e91d9 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1417 from nicholas-leonard/master

theano.tensor.batched_tensordot(x, y, axes)
...@@ -1452,7 +1452,7 @@ Linear Algebra ...@@ -1452,7 +1452,7 @@ Linear Algebra
print(b.shape) #(5,6,4,3) print(b.shape) #(5,6,4,3)
print(c.shape) #(2,3,4,5,6,4,3) print(c.shape) #(2,3,4,5,6,4,3)
See the documentation of numpy.tensordot for more examples. :note: See the documentation of `numpy.tensordot <http://docs.scipy.org/doc/numpy/reference/generated/numpy.tensordot.html>`_ for more examples.
.. function:: batched_dot(X, Y) .. function:: batched_dot(X, Y)
...@@ -1478,6 +1478,40 @@ Linear Algebra ...@@ -1478,6 +1478,40 @@ Linear Algebra
:return: tensor of products :return: tensor of products
.. function:: batched_tensordot(X, Y, axes=2)
:param x: A Tensor with sizes e.g.: for 3D (dim1, dim3, dim2)
:param y: A Tensor with sizes e.g.: for 3D (dim1, dim2, dim4)
:param axes: an integer or array. If an integer, the number of axes
to sum over. If an array, it must have two array
elements containing the axes to sum over in each tensor.
If an integer i, it is converted to an array containing
the last i dimensions of the first tensor and the first
i dimensions of the second tensor (excluding the first
(batch) dimension):
axes = [range(a.ndim - i, b.ndim), range(1,i+1)]
If an array, its two elements must contain compatible axes
of the two tensors. For example, [[1, 2], [2, 4]] means sum
over the 2nd and 3rd axes of a and the 3rd and 5th axes of b.
(Remember axes are zero-indexed!) The 2nd axis of a and the
3rd axis of b must have the same shape; the same is true for
the 3rd axis of a and the 5th axis of b.
:type axes: int or array-like of length 2
:returns: a tensor with shape equal to the concatenation of a's shape
(less any dimensions that were summed over) and b's shape
(less first dimension and any dimensions that were summed over).
:rtype: tensor of tensordots
A hybrid of batch_dot and tensordot, this function computes the
tensordot product between the two tensors, by iterating over the
first dimension using scan to perform a sequence of tensordots.
:note: See :func:`tensordot` and :func:`batched_dot` for
supplementary documentation.
Gradient / Differentiation Gradient / Differentiation
......
...@@ -3106,6 +3106,48 @@ def batched_dot(x, y): ...@@ -3106,6 +3106,48 @@ def batched_dot(x, y):
return result return result
def batched_tensordot(x, y, axes=2):
"""
:param x: A Tensor with sizes e.g.: for 3D (dim1, dim3, dim2)
:param y: A Tensor with sizes e.g.: for 3D (dim1, dim2, dim4)
:param axes: an integer or array. If an integer, the number of axes
to sum over. If an array, it must have two array
elements containing the axes to sum over in each tensor.
If an integer i, it is converted to an array containing
the last i dimensions of the first tensor and the first
i dimensions of the second tensor (excluding the first
(batch) dimension):
axes = [range(a.ndim - i, b.ndim), range(1,i+1)]
If an array, its two elements must contain compatible axes
of the two tensors. For example, [[1, 2], [2, 4]] means sum
over the 2nd and 3rd axes of a and the 3rd and 5th axes of b.
(Remember axes are zero-indexed!) The 2nd axis of a and the
3rd axis of b must have the same shape; the same is true for
the 3rd axis of a and the 5th axis of b.
:type axes: int or array-like of length 2
A hybrid of batch_dot and tensordot, this function computes the
tensordot product between the two tensors, by iterating over the
first dimension using scan to perform a sequence of tensordots.
"""
if isinstance(axes, (list, numpy.ndarray)):
if isinstance(axes, list):
axes = numpy.asarray(axes)
else:
axes = axes.copy()
assert numpy.greater(axes,0).all(), "All axes should be greater than one, as the first axis is iterated over (batch-wise scan)"
axes -= 1
result, updates = theano.scan(fn=lambda x_mat, y_mat:
theano.tensor.tensordot(x_mat, y_mat, axes),
outputs_info=None,
sequences=[x, y],
non_sequences=None)
return result
def split(x, splits_size, n_splits, axis=0): def split(x, splits_size, n_splits, axis=0):
the_split = Split(n_splits) the_split = Split(n_splits)
return the_split(x, axis, splits_size) return the_split(x, axis, splits_size)
......
...@@ -2323,8 +2323,32 @@ def test_batched_dot(): ...@@ -2323,8 +2323,32 @@ def test_batched_dot():
result_fn = theano.function([first_mat, second_mat], output) result_fn = theano.function([first_mat, second_mat], output)
result = result_fn(first_mat_val, second_mat_val) result = result_fn(first_mat_val, second_mat_val)
assert result.shape[0] == first_mat_val.shape[0]
def test_batched_tensordot():
first = theano.tensor.tensor4("first")
second = theano.tensor.tensor4("second")
axes = [[1,2], [3,1]]
output = theano.tensor.basic.batched_tensordot(first, second, axes)
first_val = numpy.random.rand(8, 10, 20, 3).astype(config.floatX)
second_val = numpy.random.rand(8, 20, 5, 10).astype(config.floatX)
result_fn = theano.function([first, second], output)
result = result_fn(first_val, second_val)
assert result.shape[0] == first_val.shape[0] assert result.shape[0] == first_val.shape[0]
assert result.shape[1] == first_val.shape[3]
assert result.shape[2] == second_val.shape[2]
first_mat = theano.tensor.dmatrix("first")
second_mat = theano.tensor.dmatrix("second")
axes = 1
output = theano.tensor.basic.batched_tensordot(first_mat, second_mat, axes)
first_mat_val = numpy.random.rand(10, 4).astype(config.floatX)
second_mat_val = numpy.random.rand(10, 4).astype(config.floatX)
result_fn = theano.function([first_mat, second_mat], output)
result = result_fn(first_mat_val, second_mat_val)
print(result.shape)
assert result.shape[0] == first_mat_val.shape[0]
assert len(result.shape) == 1
def test_tensor_values_eq_approx(): def test_tensor_values_eq_approx():
#test, inf, -inf and nan equal themself #test, inf, -inf and nan equal themself
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论