提交 f6c48641 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: implement and test

上级 713fd0fe
......@@ -3372,43 +3372,236 @@ def transpose(x, axes=None):
return ret
def batched_dot(x, y):
class BatchedDot(Op):
"""
This function computes the dot product between the two tensors, by
iterating over the first dimension using scan.
Dot but Batched
Parameters
----------
x : tensor
A Tensor with sizes e.g.: for 3D (dim1, dim3, dim2).
y : tensor
A Tensor with sizes e.g.: for 3D (dim1, dim2, dim4).
WRITEME
Computes the dot product of two variables. For two matrices, this is
equivalent to matrix multiplication. For two vectors, this is the inner
product.
Returns
-------
tensor
A tensor of size e.g. if it is 3D: (dim1, dim3, dim4).
Notes
-----
Matrix-matrix products are sometimes optimized to Dot22 or Gemm ops
(see tensor.blas).
Vector-vector products are sometimes optimized to Ger or CGer (see
tensor.blas).
Matrix-vector products are sometimes optimized to Gemv, CGemv (see
tensor.blas).
"""
__props__ = ()
def make_node(self, *inputs):
inputs = list(map(as_tensor_variable, inputs))
if len(inputs) != 2:
raise TypeError(
'theano.tensor.BatchedDot: 2 arguments required, %d given ' %
len(inputs))
if inputs[0].ndim not in (2, 3):
raise TypeError(
'theano.tensor.BatchedDot: input 0 (0-indexed) must have ndim'
' of 2 or 3, %d given. Consider calling '
'theano.tensor.batched_dot instead.' % inputs[0].ndim)
if inputs[1].ndim not in (2, 3):
raise TypeError(
'theano.tensor.BatchedDot: input 1 (0-indexed) must have ndim'
'of 2 or 3, %d given. Consider calling '
'theano.tensor.batched_dot instead.' % inputs[1].ndim)
dtype = scal.upcast(*[input.type.dtype for input in inputs])
broadcastable = (inputs[0].type.broadcastable[:-1] +
inputs[1].type.broadcastable[2:])
return Apply(self, inputs, [tensor(dtype, broadcastable)])
def perform(self, node, inp, out):
x, y = inp
z, = out
if x.shape[0] != y.shape[0]:
raise TypeError(
'theano.tensor.BatchedDot: inputs [%s] must have the same size'
'in axis 0, but have sizes [%s].'
% (", ".join(map(str, inp)),
", ".join([str(i.shape[0]) for i in inp])))
shape = self.infer_shape(node, [i.shape for i in inp])[0]
dtype = node.outputs[0].dtype
z[0] = numpy.empty(shape, dtype=dtype)
for i in xrange(z[0].shape[0]):
z[0][i] = numpy.dot(x[i], y[i])
def grad(self, inp, grads):
x, y = inp
gz, = grads
xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
# grad is scalar, so x is vector and y is vector
if gdim == 0:
xgrad = gz.dimshuffle(0, 'x') * y
ygrad = gz.dimshuffle(0, 'x') * x
# x is vector, y is matrix, grad is vector
elif xdim == 2 and ydim == 3:
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = x.dimshuffle(0, 1, 'x') * gz.dimshuffle(0, 'x', 1)
# x is matrix, y is vector, grad is vector
elif xdim == 2 and ydim == 1:
xgrad = gz.dimshuffle(0, 1, 'x') * y.dimshuffle(0, 'x', 1)
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
# x is matrix, y is matrix, grad is matrix
elif xdim == ydim == 2:
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
# If x or y contain broadcastable dimensions but only one of
# them know that a matching dimensions is broadcastable, the
# above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable:
xgrad = patternbroadcast(xgrad, x.broadcastable)
if ygrad.broadcastable != y.broadcastable:
ygrad = patternbroadcast(ygrad, y.broadcastable)
rval = xgrad, ygrad
for elem in rval:
assert elem.dtype.find('float') != -1
return rval
def R_op(self, inputs, eval_points):
# R_op for a \dot b evaluted at c for a and d for b is
# simply c \dot b + a \dot d
assert len(inputs) == 2
assert len(eval_points) == 2
if eval_points[0] is None and eval_points[1] is None:
return [None]
debugger_available = config.compute_test_value != 'off'
if debugger_available:
try:
iv0 = gof.op.get_test_value(inputs[0])
except AttributeError:
gof.op.missing_test_message(
'first input passed to Dot.R_op has no test value')
debugger_available = False
try:
iv1 = gof.op.get_test_value(inputs[1])
except AttributeError:
gof.op.missing_test_message(
'second input passed to Dot.R_op has no test value')
debugger_available = False
if eval_points[0]:
try:
ev0 = gof.op.get_test_value(eval_points[0])
except AttributeError:
gof.op.missing_test_message(
'first eval point passed to Dot.R_op '
'has no test value')
debugger_available = False
if eval_points[1]:
try:
ev1 = gof.op.get_test_value(eval_points[1])
except AttributeError:
gof.op.missing_test_message(
'second eval point passed to Dot.R_op '
'has no test value')
debugger_available = False
if debugger_available:
input_values = [iv0, iv1]
eval_point_values = [ev0, ev1]
for i in xrange(2):
if eval_point_values[i] is not None and \
input_values[i].shape != eval_point_values[i].shape:
raise ValueError(
'input ' + str(i) + ' and eval_point ' + str(i) +
' to Dot.R_op should have the same shape, but '
'their shapes are %s and %s, respectively' % (
str(input_values[i].shape),
str(eval_point_values[i].shape)))
if eval_points[0]:
t1 = self(eval_points[0], inputs[1])
if eval_points[1]:
t2 = self(inputs[0], eval_points[1])
if eval_points[0] and eval_points[1]:
return [t1 + t2]
elif eval_points[0]:
return [t1]
else:
return [t2]
def infer_shape(self, node, shapes):
for shape_ in shapes:
if len(shape_) not in (2, 3):
raise NotImplementedError()
xshp, yshp = shapes
return [xshp[:-1] + yshp[2:]]
def __str__(self):
return "batched_dot"
def batched_dot(a, b):
"""
WRITEME
Computes the dot product of two variables.
For two matrices, this is equivalent to matrix multiplication.
For two vectors, this is the inner product.
When one variable is a scalar, this is like elementwise multiplication.
For N dimensions, this is a sum product over the last axis
of the first array and the second-to-last axis of the second array:
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
Note that this dot function does one of three things, in the following
sequence:
1. If either a or b is scalar, it returns the elementwise product
without calling the Theano Dot op.
2. If either a or b has more than 2 dimensions, it calls Theano's
tensordot function with appropriate axes. The tensordot function
expresses high-dimensional dot products in terms of 2D matrix
multiplications, so it may be possible to futherize optimize for
performance.
3. If both a and b have either 1 or 2 dimensions, it calls Theano's
Dot op on a and b.
Notes
-----
This is a subset of numpy.einsum, but we do not provide it for now.
But numpy einsum is slower than dot or tensordot:
http://mail.scipy.org/pipermail/numpy-discussion/2012-October/064259.html
Matrix-matrix products are sometimes optimized to Dot22 or Gemm ops
(see tensor.blas).
Vector-vector products are sometimes optimized to Ger or CGer (see
tensor.blas).
Matrix-vector products are sometimes optimized to Gemv, CGemv (see
tensor.blas).
Examples
--------
>>> first = tensor.tensor3('first')
>>> second = tensor.tensor3('second')
>>> result = batched_dot(first, second)
"""
result, updates = theano.scan(
fn=lambda x_mat, y_mat:
theano.tensor.dot(x_mat, y_mat),
outputs_info=None,
sequences=[x, y],
non_sequences=None)
return result
"""
a, b = as_tensor_variable(a), as_tensor_variable(b)
if a.ndim == 1:
return a.dimshuffle(*([0] + ["x"] * (b.ndim - 1))) * b
elif b.ndim == 1:
return a * b.dimshuffle(*([0] + ["x"] * (a.ndim - 1)))
elif a.ndim > 3 or b.ndim > 3:
return batched_tensordot(
a, b, [[a.ndim - 1], [numpy.maximum(1, b.ndim - 2)]])
else:
return BatchedDot()(a, b)
def batched_tensordot(x, y, axes=2):
......
......@@ -31,9 +31,9 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
horizontal_stack, vertical_stack, argmax, get_vector_length,
fscalar, zeros_like, sum, tensor3, vector, add, addbroadcast,
alloc, as_tensor_variable, tensor_from_scalar, ARange, autocast_float,
clip, constant, default, dot,
dmatrix, dscalar, dvector, eq, eye, fill, flatten, inverse_permutation, Flatten,
tensor4, permute_row_elements, fmatrix, fscalars, grad,
clip, constant, default, dot, batched_dot,
dmatrix, dscalar, dvector, eq, eye, fill, flatten, inverse_permutation,
tensor4, permute_row_elements, Flatten, fmatrix, fscalars, grad,
inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq,
Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor,
tensor_copy, tensordot, TensorType, Tri, tri, tril, triu, unbroadcast,
......@@ -1932,6 +1932,36 @@ DotTester = makeTester(name='DotTester',
bad_runtime=dict(bad1=(rand(5, 7), rand(5, 7)),
bad2=(rand(5, 7), rand(8, 3))))
BatchedDotTester = makeTester(
name='BatchedDotTester',
op=batched_dot,
expected=lambda xs, ys: numpy.asarray(list(
numpy.dot(x, y) for x, y in zip(xs, ys))),
checks={},
good=dict(correct1=(rand(3, 5, 7), rand(3, 7, 5)),
correct2=(rand(3, 5, 7), rand(3, 7, 9)),
correct3=(rand(3, 5, 7), rand(3, 7)),
correct4=(rand(3, 5), rand(3, 5, 7)),
correct5=(rand(3), rand(3, 5, 7)),
correct6=(rand(3, 5), rand(3)),
mixed1=(rand(3, 5).astype('float32'),
rand(3, 5, 7)),
mixed2=(rand(3, 5).astype('float64'),
rand(3, 5, 7)),
complex1=(randcomplex(3, 5, 7),
randcomplex(3, 7)),
complex2=(rand(3, 5, 7), randcomplex(3, 7)),
complex3=(randcomplex(3, 5, 7), rand(3, 7)),
empty1=(numpy.asarray([], dtype=config.floatX),
numpy.asarray([], dtype=config.floatX)),
empty2=(rand(3, 5, 0), rand(3, 0, 2)),
empty3=(rand(3, 0, 5), rand(3, 5, 0))),
bad_build=dict(),
bad_runtime=dict(bad1=(rand(3, 5, 7), rand(3, 5, 7)),
bad2=(rand(3, 5, 7), rand(3, 8, 3)),
bad3=(rand(2, 5, 7), rand(3, 7, 9)),
bad4=(rand(3, 5, 7), rand(2, 7, 9))))
def _numpy_second(x, y):
return numpy.broadcast_arrays(x, y)[1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论