提交 5463e79d authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

replaced TensorDot and TensorDotGrad Ops with a function

上级 244d529f
...@@ -7109,184 +7109,25 @@ pprint.assign(dot, printing.OperatorPrinter(printing.special['middle_dot'], ...@@ -7109,184 +7109,25 @@ pprint.assign(dot, printing.OperatorPrinter(printing.special['middle_dot'],
######################### #########################
# Linalg : TensorDot # Linalg : TensorDot
######################### #########################
class TensorDotGrad(Op):
    """Gradient Op for TensorDot.

    Given inputs x, y and the output gradient gz, computes the gradients
    gx and gy of tensordot(x, y, axes).  Each gradient is itself a
    numpy.tensordot of gz with the other input, followed by a transpose
    that restores the original axis order.
    """

    def __init__(self, axes):
        # Normalize `axes` through TensorDot's parser: an int is kept
        # as-is, a list/tuple becomes a pair of equal-length tuples.
        self.axes = TensorDot.parse_axes(axes)
        if isinstance(self.axes, (tuple, list)) and len(self.axes) == 2:
            # perform() below assumes each axes tuple is sorted in
            # increasing order; the unsorted cases are not implemented
            # correctly, so reject them up front.
            for i in range(len(self.axes[0]) - 1):
                if self.axes[0][i] > self.axes[0][i + 1]:
                    raise NotImplementedError()
                if self.axes[1][i] > self.axes[1][i + 1]:
                    raise NotImplementedError()

    def __eq__(self, other):
        # Two grad Ops are interchangeable iff same class and same axes.
        return type(self) == type(other) and self.axes == other.axes

    def __hash__(self):
        # Must stay consistent with __eq__; 89234 is an arbitrary salt.
        return hashtype(self) ^ hash(self.axes) ^ 89234

    def make_node(self, x, y, gz):
        assert isinstance(x, Variable)
        assert isinstance(y, Variable)
        assert isinstance(gz, Variable)
        # Each gradient keeps the broadcast pattern of its input and
        # upcasts its dtype with the other operands involved.
        gx = tensor(dtype=scal.upcast(gz.dtype, y.dtype),
                    broadcastable=x.broadcastable)
        gy = tensor(dtype=scal.upcast(x.dtype, gz.dtype),
                    broadcastable=y.broadcastable)
        op = self
        if isinstance(self.axes, int):
            # Expand an integer spec (last `axes` dims of x against the
            # first `axes` dims of y) into explicit axis lists so that
            # perform() only ever sees the list form.
            axes = [range(x.ndim - self.axes, x.ndim), range(self.axes)]
            op = TensorDotGrad(axes)
        return Apply(op, [x, y, gz], [gx, gy])

    def perform(self, node, inp, out):
        """Compute gx, gy with numpy.

        NOTE: Python 2 only — `range` must return a list for the
        in-place .remove() calls below to work.
        """
        x, y, gz = inp
        gx, gy = out
        # Axes of y / x that are NOT summed over (these survive into gz:
        # gz's leading axes come from x, its trailing axes from y).
        sum_over_y = range(y.ndim)
        [sum_over_y.remove(q) for q in self.axes[1]]
        sum_over_x = range(x.ndim)
        [sum_over_x.remove(q) for q in self.axes[0]]
        # gx: contract gz's trailing axes (the ones that came from y)
        # against y's non-summed axes ...
        tdot_axes = [range(x.ndim - len(self.axes[0]), gz.ndim), sum_over_y]
        _gx = numpy.tensordot(gz, y, tdot_axes)
        # ... then build the permutation that maps _gx's axis order
        # (non-summed x axes first, then summed ones) back to x's order.
        idx = numpy.hstack((sum_over_x, self.axes[0]))
        newshapex = numpy.zeros(x.ndim)
        newshapex[[newpos for newpos in idx]] = range(x.ndim)
        gx[0] = numpy.transpose(_gx, newshapex)
        # gy: contract x's non-summed axes against gz's leading axes
        # (the ones that came from x) ...
        tdot_axes = [sum_over_x, range(x.ndim - len(self.axes[0]))]
        _gy = numpy.tensordot(x, gz, tdot_axes)
        # ... then permute back to y's original axis order.
        idy = numpy.hstack((self.axes[1], sum_over_y))
        newshapey = numpy.zeros(y.ndim)
        newshapey[[newpos for newpos in idy]] = range(y.ndim)
        gy[0] = numpy.transpose(_gy, newshapey)
        assert gy[0].shape == y.shape
        assert gx[0].shape == x.shape

    def infer_shape(self, node, in_shapes):
        # Each gradient has exactly the shape of the matching input.
        return in_shapes[:2]


# Lowercase alias; this is the name TensorDot.grad instantiates.
tensordot_grad = TensorDotGrad
class TensorDot(Op):
    """
    Compute tensor-tensor dot products along specified axes.

    Given two tensors A and B, TensorDot takes the product of elements
    in A and B and sums that result over the provided axes.

    See documentation for theano.tensor.tensordot for more detail,
    or the NumPy documentation at:
    http://docs.scipy.org/doc/numpy/reference/generated/numpy.tensordot.html
    """

    @classmethod
    def parse_axes(cls, axes):
        # Validate and normalize an `axes` argument.  A scalar is
        # returned unchanged; a list/tuple is converted to a pair of
        # equal-length tuples (hashable, so usable in __eq__/__hash__).
        if not numpy.isscalar(axes) and len(axes) != 2:
            raise ValueError("Axes should be scalar valued or a list/tuple of "
                             "len 2.")
        if isinstance(axes, (list, tuple)):
            axes_out = []
            # cast axes[0] and axes[1] to tuples
            for i, a in enumerate(axes):
                if numpy.isscalar(a):
                    axes_out.append((a,))
                else:
                    axes_out.append(tuple(a))
            # these should be of same length
            if len(axes_out[0]) != len(axes_out[1]):
                raise ValueError("Elements of the axes list/tuple need to be "
                                 "of the same size.")
            axes = tuple(axes_out)
        return axes

    def __init__(self, axes):
        self.axes = self.parse_axes(axes)

    def __eq__(self, other):
        # Ops compare equal iff same class and same (normalized) axes.
        return type(self) == type(other) and self.axes == other.axes

    def __hash__(self):
        # Consistent with __eq__; 89234 is an arbitrary salt.
        return hashtype(self) ^ hash(self.axes) ^ 89234

    def make_node(self, x, y):
        op = self
        if isinstance(self.axes, int):
            # Expand an integer spec (last `axes` dims of x against the
            # first `axes` dims of y) into explicit axis lists, and use
            # a fresh Op carrying the list form in the graph.
            axes = [range(x.ndim - self.axes, x.ndim), range(self.axes)]
            op = TensorDot(axes)
        # Number of axis pairs summed over.  NOTE: Python 2 `/` — this
        # relies on integer division.
        axesdim = numpy.size(op.axes) / 2
        x, y = map(as_tensor_variable, [x, y])
        if axesdim > x.type.ndim or axesdim > y.type.ndim:
            raise TypeError('Cannot sum over more dimensions than input. '
                            '%i > %i,%i' %
                            (axesdim, x.type.ndim, y.type.ndim))
        # Output rank: both inputs' dims, minus the contracted pairs.
        outdim = x.type.ndim + y.type.ndim - 2 * axesdim
        output = tensor(dtype=scal.upcast(x.dtype, y.dtype),
                        broadcastable=[False] * outdim)
        return Apply(op, inputs=[x, y], outputs=[output, ])

    def perform(self, node, inp, out):
        # Delegate the actual contraction to numpy.
        # NOTE: Python 2 `except ..., e` syntax.
        x, y = inp
        z, = out
        try:
            z[0] = numpy.asarray(numpy.tensordot(x, y, self.axes))
        except ValueError, e:
            # The error raised by numpy has no shape information, we mean to
            # add that.
            e.args = e.args + (x.shape, y.shape, self.axes)
            raise

    def infer_shape(self, node, in_shapes):
        shape_x, shape_y = in_shapes
        out_shape = []
        if isinstance(self.axes, (list, tuple)):
            # Output shape: x's non-summed dims followed by y's.
            iter = (i for i in range(len(shape_x)) if i not in self.axes[0])
            for i in iter:
                out_shape.append(shape_x[i])
            iter = (i for i in range(len(shape_y)) if i not in self.axes[1])
            for i in iter:
                out_shape.append(shape_y[i])
        else:
            # NOTE(review): this branch looks broken — shape_x is a
            # sequence (no .ndim attribute) and the second index is a
            # tuple, not a slice.  In practice make_node rewrites int
            # axes into the list form above, so this is presumably dead
            # code — verify before relying on it.
            out_shape = list(shape_x)[shape_x.ndim - self.axes] + \
                list(shape_y)[shape_y.ndim - self.axes, shape_y.ndim]
        return [out_shape]

    def grad(self, inp, grads):
        # Both gradients come from a single TensorDotGrad apply node.
        x, y = inp
        gz, = grads
        gx, gy = tensordot_grad(self.axes)(x, y, gz)
        return [gx, gy]

    def __str__(self):
        return "tensordot"
def tensordot(x, y=None, axes=2): def tensordot(a, b, axes = 2):
""" """
Compute tensor-tensor dot products along specified axes. Given two tensors a and b, tensordot computes a generalized dot product over
the provided axes. This implementation reduces all expressions to matrix or
vector dot products and is based on code from Tijmen Tieleman's gnumpy
(http://www.cs.toronto.edu/~tijmen/gnumpy.html).
Given two tensors A and B, TensorDot takes the product of elements :param a: the first tensor variable
in A and B and sums that result over the provided axes.
:param x: the first tensor variable :param b: the second tensor variable
:param y: the second tensor variable
:param axes: an integer or array. If an integer, the number of axes :param axes: an integer or array. If an integer, the number of axes
to sum over. If an array, it must have two array to sum over. If an array, it must have two array
elements containing the axes to sum over in each tensor. elements containing the axes to sum over in each tensor.
Note that the default value of 2 is not guaranteed to work Note that the default value of 2 is not guaranteed to work
for all values of x and y, and an error will be raised if for all values of a and b, and an error will be raised if
that is the case. The reason for keeping the default is to that is the case. The reason for keeping the default is to
maintain the same signature as numpy's tensordot function maintain the same signature as numpy's tensordot function
(and np.tensordot raises analogous errors for non-compatible (and np.tensordot raises analogous errors for non-compatible
...@@ -7295,99 +7136,129 @@ def tensordot(x, y=None, axes=2): ...@@ -7295,99 +7136,129 @@ def tensordot(x, y=None, axes=2):
If an integer i, it is converted to an array containing If an integer i, it is converted to an array containing
the last i dimensions of the first tensor and the first the last i dimensions of the first tensor and the first
i dimensions of the second tensor: i dimensions of the second tensor:
axes = [range(x.ndim - i, x.ndim), range(i)] axes = [range(a.ndim - i, a.ndim), range(i)]
If an array, its two elements must contain compatible axes If an array, its two elements must contain compatible axes
of the two tensors. For example, [[1, 2], [2, 0]] means sum of the two tensors. For example, [[1, 2], [2, 0]] means sum
over the 2nd and 3rd axes of x and the 3rd and 1st axes of y. over the 2nd and 3rd axes of a and the 3rd and 1st axes of b.
(Remember axes are zero-indexed!) The 2nd axis of x and the (Remember axes are zero-indexed!) The 2nd axis of a and the
3rd axis of y must have the same shape; the same is true for 3rd axis of b must have the same shape; the same is true for
the 3rd axis of x and the 1st axis of y. the 3rd axis of a and the 1st axis of b.
:returns: a tensor with shape equal to the concatenation of x's shape :returns: a tensor with shape equal to the concatenation of a's shape
(less any dimensions that were summed over) and y's shape (less any dimensions that were summed over) and b's shape
(less any dimensions that were summed over). (less any dimensions that were summed over).
It may be helpful to consider an example to see what tensordot does. It may be helpful to consider an example to see what tensordot does.
Theano's implementation is identical to NumPy's. Here x has shape (2, 3, 4) Theano's implementation is identical to NumPy's. Here a has shape (2, 3, 4)
and y has shape (5, 6, 4, 3). The axes to sum over are [[1, 2], [3, 2]] -- and b has shape (5, 6, 4, 3). The axes to sum over are [[1, 2], [3, 2]] --
note that x.shape[1] == y.shape[3] and x.shape[2] == y.shape[2]; these axes note that a.shape[1] == b.shape[3] and a.shape[2] == b.shape[2]; these axes
are compatible. The resulting tensor will have shape (2, 5, 6) -- the are compatible. The resulting tensor will have shape (2, 5, 6) -- the
dimensions that are not being summed: dimensions that are not being summed:
x = np.random.random((2,3,4)) a = np.random.random((2,3,4))
y = np.random.random((5,6,4,3)) b = np.random.random((5,6,4,3))
#tensordot #tensordot
z = np.tensordot(x, y, [[1,2],[3,2]]) c = np.tensordot(a, b, [[1,2],[3,2]])
#loop replicating tensordot #loop replicating tensordot
x0, x1, x2 = x.shape a0, a1, a2 = a.shape
y0, y1, _, _ = y.shape b0, b1, _, _ = b.shape
zloop = np.zeros((x0,y0,y1)) cloop = np.zeros((a0,b0,b1))
#loop over non-summed indices -- these exist #loop over non-summed indices -- these exist
#in the tensor product. #in the tensor product.
for i in range(x0): for i in range(a0):
for j in range(y0): for j in range(b0):
for k in range(y1): for k in range(b1):
#loop over summed indices -- these don't exist #loop over summed indices -- these don't exist
#in the tensor product. #in the tensor product.
for l in range(x1): for l in range(a1):
for m in range(x2): for m in range(a2):
zloop[i,j,k] += x[i,l,m] * y[j,k,m,l] cloop[i,j,k] += a[i,l,m] * b[j,k,m,l]
np.allclose(c, cloop) #true
np.all(z == zloop) #true This specific implementation avoids a loop by transposing a and b such that
the summed axes of a are last and the summed axes of b are first. The
resulting arrays are reshaped to 2 dimensions (or left as vectors, if
appropriate) and a matrix or vector dot product is taken. The result is
reshaped back to the required output dimensions.
In an extreme case, no axes may be specified. The resulting tensor In an extreme case, no axes may be specified. The resulting tensor
will have shape equal to the concatenation of the shapes of x and y: will have shape equal to the concatenation of the shapes of a and b:
z = np.tensordot(x, y, 0) c = np.tensordot(a, b, 0)
print(x.shape) #(2,3,4) print(a.shape) #(2,3,4)
print(y.shape) #(5,6,4,3) print(b.shape) #(5,6,4,3)
print(z.shape) #(2,3,4,5,6,4,3) print(c.shape) #(2,3,4,5,6,4,3)
See the documentation of np.tensordot for more examples. See the documentation of np.tensordot for more examples.
""" """
if y is None: # axes must be a scalar or list/tuple of length 2
raise NotImplementedError( if not numpy.isscalar(axes) and len(axes) != 2:
'The interface to tensordot has changed from ' raise ValueError('Axes should be scalar valued or a '
'tensor.tensordot(axes)(x,y) to tensor.tensordot(x,y,axes). ' 'list/tuple of len 2.')
'Please modify your code accordingly.')
# if 'axes' is a number of axes to multiply and sum over (trailing axes
if x.ndim == 0 or y.ndim == 0: # of a, leading axes of b), we can just reshape and use dot.
raise ValueError('Cannot perform tensordot of 0-d inputs.') elif numpy.isscalar(axes):
# check if axes is valid given the dimension of a and b
axes = TensorDot.parse_axes(axes) if axes > a.ndim or axes > b.ndim:
raise ValueError('axes should be smaller than the dimension of '
# check whether axes is valid given the dimensions of x and y 'a and b (a.ndim=%i, b.ndim=%i)' % (a.ndim, b.ndim))
if numpy.isscalar(axes):
if axes >= x.ndim or axes >= y.ndim: outshape = concatenate([a.shape[:a.ndim - axes], b.shape[axes:]])
raise ValueError('axes should be smaller than the dimension of '\ outndim = a.ndim + b.ndim - (2 * axes)
'x and y (x.ndim=%i, y.ndim=%i)' % (x.ndim, y.ndim)) a_reshaped = a.reshape((prod(a.shape[:a.ndim - axes]),
elif isinstance(axes, (list, tuple)): prod(a.shape[a.ndim - axes:])),
ndim = 2)
if isinstance(axes[0], (list, tuple)) and \ b_reshaped = b.reshape((prod(b.shape[:axes]),
(len(axes[0]) > x.ndim or (numpy.array(axes[0]) >= x.ndim).any()): prod(b.shape[axes:])),
raise ValueError('axes[0] should be array_like, of length smaller'\ ndim = 2)
' than the dimension of x (x.ndim=%i, len(axes[0])=%i).' %
(x.ndim, len(axes[0]))) return dot(a_reshaped, b_reshaped).reshape(outshape, outndim)
if isinstance(axes[1], (list, tuple)) and \ # if 'axes' is a list, transpose a and b such that the summed axes of a
(len(axes[1]) > y.ndim or (numpy.array(axes[1]) >= y.ndim).any()): # are last and the summed axes of b are first.
raise ValueError('axes[1] should be array_like, of length smaller'\ else:
'than the dimension of y (y.ndim=%i, len(axes[1])=%i).' % a_axes, b_axes = tuple(axes[0]), tuple(axes[1])
(y.ndim, len(axes[1])))
# check that axes is valid given dimension of a and b
if not hasattr(tensordot, 'op'): if len(a_axes) > a.ndim:
tensordot.op = {} raise ValueError('axes[0] should be array_like, of length '
'smaller than the dimension of a '
if axes not in tensordot.op: '(a.ndim=%i, len(axes[0])=%i).' %
tensordot.op[axes] = TensorDot(axes) (a.ndim, a_axes))
if numpy.max(numpy.array(a_axes)) > a.ndim:
return tensordot.op[axes](x, y) raise ValueError('axes[0] contains dimensions higher than a.ndim '
'(a.ndim=%i, max(axes[0])=%i).' %
# TODO: tensordot should be function as described in rst docs. (a.ndim, numpy.max(numpy.array(a_axes))))
if len(b_axes) > b.ndim:
raise ValueError('axes[1] should be array_like, of length '
'smaller than the dimension of b '
'(a.ndim=%i, len(axes[0])=%i).' %
(b.ndim, b_axes))
if numpy.max(numpy.array(b_axes)) > b.ndim:
raise ValueError('axes[1] contains dimensions higher than b.ndim '
'(b.ndim=%i, max(axes[1])=%i).' %
(b.ndim, numpy.max(numpy.array(b_axes))))
# the two axes lists must have the same length
if len(a_axes) != len(b_axes):
raise ValueError('Axes elements must have the same length.')
a_order = (tuple(x for x in tuple(xrange(a.ndim)) if x not in a_axes)
+ a_axes)
b_order = (b_axes
+ tuple(x for x in tuple(xrange(b.ndim)) if x not in b_axes))
a_shuffled = a.dimshuffle(a_order)
b_shuffled = b.dimshuffle(b_order)
# now that a and b are in the right order, call tensordot recursively
return tensordot(a_shuffled, b_shuffled, len(a_axes))
def outer(x, y): def outer(x, y):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论