Commit 856aa0b6 authored by Frédéric Bastien

Merge pull request #3267 from koningrobot/tensordot-as-dot

Implement batched_tensordot in terms of batched_dot
@@ -1056,6 +1056,16 @@ def _scal_elemwise_with_nfunc(nfunc, nin, nout):

_scal_elemwise = _scal_elemwise_with_nfunc(None, None, None)

def _pack(x):
    """
    Convert x to a list if it is an iterable, otherwise wrap it in a list.
    """
    try:
        return list(x)
    except TypeError:
        return [x]
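
Aside (not part of the commit): a minimal sketch of how the _pack helper
normalizes its argument; the values are illustrative only.

# _pack converts iterables to lists and wraps anything else in a list.
assert _pack((0, 1)) == [0, 1]    # iterables become lists
assert _pack(2) == [2]            # scalars are wrapped
assert _pack("ab") == ["a", "b"]  # strings are iterable, so they split
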
#########################
# Casting Operations
#########################
@@ -3357,24 +3367,11 @@ def batched_tensordot(x, y, axes=2):
    3rd axis of b must have the same shape; the same is true for
    the 3rd axis of a and the 5th axis of b.

    Like tensordot, this function uses a series of dimshuffles and
    reshapes to reduce the tensor dot product to a matrix or vector
    dot product. Finally, it calls batched_dot to compute the result.
    """
    if isinstance(axes, (list, numpy.ndarray)):
        if isinstance(axes, list):
            axes = numpy.asarray(axes)
        else:
            axes = axes.copy()
        assert numpy.greater(axes, 0).all(), (
            "All axes should be greater than one, as the "
            "first axis is iterated over (batch-wise scan)")
        axes -= 1
    result, updates = theano.scan(
        fn=lambda x_mat, y_mat:
        theano.tensor.tensordot(x_mat, y_mat, axes),
        outputs_info=None,
        sequences=[x, y],
        non_sequences=None)
    return result

    return _tensordot_as_dot(x, y, axes, dot=batched_dot, batched=True)
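
Aside (not part of the commit): the new one-line body delegates to
_tensordot_as_dot, added further down in this commit. A minimal NumPy sketch
of the semantics being preserved; shapes here are illustrative, not from the
commit.

import numpy

# batched_tensordot(x, y, axes) behaves like looping tensordot over axis 0:
# the batch axis is preserved, the remaining axes are contracted.
x = numpy.random.rand(5, 2, 3)
y = numpy.random.rand(5, 3, 4)
expected = numpy.array([numpy.tensordot(xm, ym, axes=1)
                        for xm, ym in zip(x, y)])
print(expected.shape)  # (5, 2, 4): batch axis 0 carried through
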
def split(x, splits_size, n_splits, axis=0):
@@ -5273,6 +5270,129 @@ def dot(a, b):
# Linalg : TensorDot
#########################
def _tensordot_as_dot(a, b, axes, dot, batched):
    """
    Reduces a tensor dot product to a matrix or vector dot product. Based
    on code from Tijmen Tieleman's gnumpy
    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).

    Please see the documentation of tensordot for the meaning of the a, b
    and axes arguments.

    :param dot: a function that accepts two symbolic variables and computes
                the appropriate dot product (e.g. dot, batched_dot)
    :type dot: function

    :param batched: whether to treat the first axis of a and b as a batch
                    axis. If so, this axis is preserved in the output,
                    which allows this function to also be used for batched
                    tensor dot products.
    :type batched: boolean

    :returns: a tensor with shape equal to the concatenation of a's shape
              (less any dimensions that were summed over) and b's shape
              (less the first dimension and any dimensions that were summed
              over).
    :rtype: symbolic tensor
    """
    a, b = as_tensor_variable(a), as_tensor_variable(b)

    if not numpy.isscalar(axes) and len(axes) != 2:
        raise ValueError('Axes should be an integer or a '
                         'list/tuple of len 2 (%s was provided)'
                         % repr(axes))

    # if 'axes' is a number of axes to multiply and sum over (trailing axes
    # of a, leading axes of b), we can just reshape and use dot.
    elif numpy.isscalar(axes):
        axes = int(axes)

        for operand_name, operand in (("a", a), ("b", b)):
            if axes > operand.ndim:
                raise ValueError(
                    'axes can not be larger than the dimension of %s '
                    '(%s.ndim=%i, axes=%i)'
                    % (operand_name, operand_name, operand.ndim, axes))
            if batched and axes == operand.ndim:
                raise ValueError(
                    'axes to sum over must not include the batch axis '
                    'of %s (%s.ndim=%i, axes=%i)'
                    % (operand_name, operand_name, operand.ndim, axes))

        batch_axes = 1 if batched else 0
        a_outaxes = slice(0, a.ndim - axes)
        b_outaxes = slice(batch_axes + axes, b.ndim)
        outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
        outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
        outndim = len(outbcast)

        a_shape = [1] * 2
        b_shape = [1] * 2

        # compute total size of summed axes
        for i in xrange(0, axes):
            a_shape[1] *= a.shape[-(i + 1)]
            b_shape[0] *= b.shape[batch_axes + i]
        # compute total size of other axes
        for i in xrange(0, a.ndim - axes - batch_axes):
            a_shape[0] *= a.shape[batch_axes + i]
        for i in xrange(0, b.ndim - axes - batch_axes):
            b_shape[1] *= b.shape[-(i + 1)]

        if batched:
            a_shape.insert(0, a.shape[0])
            b_shape.insert(0, b.shape[0])

        a_reshaped = a.reshape(a_shape)
        b_reshaped = b.reshape(b_shape)

        out_reshaped = dot(a_reshaped, b_reshaped)
        out = out_reshaped.reshape(outshape, outndim)
        # Make sure the broadcastable pattern of the result is correct,
        # since some shape information can be lost in the reshapes.
        return patternbroadcast(out, outbcast)
    # if 'axes' is a list, transpose a and b such that the summed axes of a
    # are last and the summed axes of b are first.
    else:
        axes = [_pack(axes_) for axes_ in axes]

        if len(axes[0]) != len(axes[1]):
            raise ValueError('Axes elements must have the same length.')

        for i, (operand_name, operand) in enumerate((("a", a),
                                                     ("b", b))):
            if len(axes[i]) > operand.ndim:
                raise ValueError(
                    'axes[%i] should be array_like with length no greater '
                    'than the dimensions of %s (%s.ndim=%i, '
                    'len(axes[%i])=%i).'
                    % (i, operand_name, operand_name, operand.ndim,
                       i, len(axes[i])))
            if len(axes[i]) > 0 and numpy.max(axes[i]) >= operand.ndim:
                raise ValueError(
                    'axes[%i] contains dimensions greater than or equal '
                    'to %s.ndim (%s.ndim=%i, max(axes[%i])=%i).'
                    % (i, operand_name, operand_name, operand.ndim,
                       i, numpy.max(numpy.array(axes[i]))))
            if batched and 0 in axes[i]:
                raise ValueError(
                    'axes to sum over must not contain the batch axis '
                    '(axes[%i]=%s)' % (i, axes[i]))

        batch_axes = [0] if batched else []
        other_axes = [[x for x in xrange(operand.ndim)
                       if x not in axes[i] and x not in batch_axes]
                      for i, operand in enumerate((a, b))]

        a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
        b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])

        # now that a and b are in the right order, recur with integer axes
        return _tensordot_as_dot(a_shuffled, b_shuffled, len(axes[0]),
                                 dot=dot, batched=batched)
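
Aside (not part of the commit): a hedged NumPy sketch of the reduction the
scalar-axes branch performs symbolically; shapes are illustrative only.

import numpy

# Collapse a's non-summed axes and its summed axes into one axis each
# (likewise for b), take a single matrix dot, then restore the output shape.
a = numpy.random.rand(2, 3, 4)  # summed axis: the last one (size 4)
b = numpy.random.rand(4, 5)     # summed axis: the first one (size 4)

a2 = a.reshape(2 * 3, 4)  # (prod of other axes, prod of summed axes)
b2 = b.reshape(4, 5)      # (prod of summed axes, prod of other axes)
out = numpy.dot(a2, b2).reshape(2, 3, 5)

assert numpy.allclose(out, numpy.tensordot(a, b, axes=1))
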
def tensordot(a, b, axes=2):
    """
    Compute a generalized dot product over provided axes.

@@ -5373,108 +5493,7 @@ def tensordot(a, b, axes=2):

    See the documentation of numpy.tensordot for more examples.
    """
    a, b = as_tensor_variable(a), as_tensor_variable(b)
    # axes must be a scalar or list/tuple of length 2
    if not numpy.isscalar(axes) and len(axes) != 2:
        raise ValueError('Axes should be an integer or a '
                         'list/tuple of len 2 (%s was provided)' % repr(axes))

    # if 'axes' is a number of axes to multiply and sum over (trailing axes
    # of a, leading axes of b), we can just reshape and use dot.
    elif numpy.isscalar(axes):
        axes = int(axes)

        # check if axes is valid given the dimension of a and b
        if axes > a.ndim:
            raise ValueError('axes can not be larger than the dimension of '
                             'a (a.ndim=%i, axes=%i)' % (a.ndim, axes))
        if axes > b.ndim:
            raise ValueError('axes can not be larger than the dimension '
                             'of b (b.ndim=%i, axes=%i)' % (b.ndim, axes))

        outshape = concatenate([a.shape[:a.ndim - axes], b.shape[axes:]])
        outbcast = a.broadcastable[:a.ndim - axes] + b.broadcastable[axes:]
        outndim = a.ndim + b.ndim - (2 * axes)

        a_shape_0 = b_shape_0 = a_shape_1 = b_shape_1 = 1
        for s0 in xrange(a.ndim - axes):
            a_shape_0 *= a.shape[s0]
        for s0 in xrange(axes):
            b_shape_0 *= b.shape[s0]
        for s1 in xrange(a.ndim - axes, a.ndim):
            a_shape_1 *= a.shape[s1]
        for s1 in xrange(axes, b.ndim):
            b_shape_1 *= b.shape[s1]

        a_reshaped = a.reshape((a_shape_0, a_shape_1), ndim=2)
        b_reshaped = b.reshape((b_shape_0, b_shape_1), ndim=2)

        out = _dot(a_reshaped, b_reshaped).reshape(outshape, outndim)
        # Make sure the broadcastable pattern of the result is correct,
        # since some shape information can be lost in the reshapes.
        return patternbroadcast(out, outbcast)
    # if 'axes' is a list, transpose a and b such that the summed axes of a
    # are last and the summed axes of b are first.
    else:
        # get first axis element as a tuple
        try:
            a_axes = tuple(axes[0])
        except TypeError:
            a_axes = tuple([axes[0]])

        # get second axis element as a tuple
        try:
            b_axes = tuple(axes[1])
        except TypeError:
            b_axes = tuple([axes[1]])

        # the two axes lists must have the same length
        if len(a_axes) != len(b_axes):
            raise ValueError('Axes elements must have the same length.')

        # check that there aren't more axes than a has dimensions
        if len(a_axes) > a.ndim:
            raise ValueError('axes[0] should be array_like with length '
                             'no greater than the dimensions of a '
                             '(a.ndim=%i, len(axes[0])=%i).' %
                             (a.ndim, len(a_axes)))

        # check that a_axes doesn't contain an axis greater than or equal to
        # a's dimensions. also check if len > 0 so numpy.max won't raise an
        # error.
        if len(a_axes) > 0 and numpy.max(numpy.array(a_axes)) >= a.ndim:
            raise ValueError('axes[0] contains dimensions greater than or '
                             'equal to a.ndim (a.ndim=%i, max(axes[0])=%i).' %
                             (a.ndim, numpy.max(numpy.array(a_axes))))

        # check that there aren't more axes than b has dimensions
        if len(b_axes) > b.ndim:
            raise ValueError('axes[1] should be array_like with length '
                             'no greater than the dimensions of b '
                             '(b.ndim=%i, len(axes[1])=%i).' %
                             (b.ndim, len(b_axes)))

        # check that b_axes doesn't contain an axis greater than or equal to
        # b's dimensions. also check if len > 0 so numpy.max won't raise an
        # error.
        if len(b_axes) > 0 and numpy.max(numpy.array(b_axes)) >= b.ndim:
            raise ValueError('axes[1] contains dimensions greater than or '
                             'equal to b.ndim (b.ndim=%i, max(axes[1])=%i).' %
                             (b.ndim, numpy.max(numpy.array(b_axes))))

        a_order = (tuple(x for x in tuple(xrange(a.ndim)) if x not in a_axes)
                   + a_axes)
        b_order = (b_axes
                   + tuple(x for x in tuple(xrange(b.ndim))
                           if x not in b_axes))

        a_shuffled = a.dimshuffle(a_order)
        b_shuffled = b.dimshuffle(b_order)

        # now that a and b are in the right order, call tensordot recursively
        return tensordot(a_shuffled, b_shuffled, len(a_axes))

    return _tensordot_as_dot(a, b, axes, dot=dot, batched=False)
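
Aside (not part of the commit): the list-axes path, in both the old code
above and the new _tensordot_as_dot, only moves the summed axes into
position and then recurses with an integer. A NumPy sketch of that
transposition, with illustrative shapes:

import numpy

a = numpy.random.rand(4, 2, 3)
b = numpy.random.rand(3, 4, 5)

a_t = a.transpose(1, 0, 2)  # other axes first, summed axes (0, 2) last
b_t = b.transpose(1, 0, 2)  # summed axes (1, 0) first, other axes last

# After the transposes, a plain integer-axes contraction does the job.
out = numpy.tensordot(a_t, b_t, axes=2)
assert numpy.allclose(out, numpy.tensordot(a, b, axes=[[0, 2], [1, 0]]))
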
def outer(x, y):
...