提交 f6c48641 作者: Tim Cooijmans

BatchedDot: implement and test

上级 713fd0fe
...@@ -3372,43 +3372,236 @@ def transpose(x, axes=None): ...@@ -3372,43 +3372,236 @@ def transpose(x, axes=None):
return ret return ret
class BatchedDot(Op):
    """
    Computes the batched dot product of two variables.

    The first axis of both inputs is treated as the batch axis: for every
    index ``i`` along it, the output holds ``numpy.dot(x[i], y[i])``.

    Each input must have 2 or 3 dimensions (enforced when the node is
    built), and both inputs must have the same size along axis 0 (checked
    when the node is executed).

    The output shape is ``x.shape[:-1] + y.shape[2:]``, e.g. inputs of
    shape ``(dim1, dim3, dim2)`` and ``(dim1, dim2, dim4)`` give an output
    of shape ``(dim1, dim3, dim4)``.

    Notes
    -----
    Inputs with other numbers of dimensions are handled by the
    ``batched_dot`` function, which only falls back to this op for the
    2-D / 3-D cases.

    """
    # The op has no parameters, so all instances are interchangeable.
    __props__ = ()
def make_node(self, *inputs):
    """
    Build the Apply node for a batched dot of exactly two tensors.

    Parameters
    ----------
    *inputs
        The two operands; each is converted with ``as_tensor_variable``
        and must have 2 or 3 dimensions.

    Returns
    -------
    Apply
        A node with one output whose dtype is the upcast of both input
        dtypes and whose broadcast pattern is that of the first input
        without its last axis, followed by the axes of the second input
        from index 2 onwards.

    Raises
    ------
    TypeError
        If the number of inputs is not 2, or if an input does not have
        2 or 3 dimensions.

    """
    inputs = list(map(as_tensor_variable, inputs))

    if len(inputs) != 2:
        raise TypeError(
            'theano.tensor.BatchedDot: 2 arguments required, %d given ' %
            len(inputs))

    # Both operands obey the same rank rule, so validate them in one loop;
    # this also keeps the two error messages identical in wording.
    for position, variable in enumerate(inputs):
        if variable.ndim not in (2, 3):
            raise TypeError(
                'theano.tensor.BatchedDot: input %d (0-indexed) must have'
                ' ndim of 2 or 3, %d given. Consider calling '
                'theano.tensor.batched_dot instead.'
                % (position, variable.ndim))

    dtype = scal.upcast(*[variable.type.dtype for variable in inputs])
    # Drop the contracted (last) axis of the first operand, and the batch
    # and contracted axes (0 and 1) of the second operand.
    broadcastable = (inputs[0].type.broadcastable[:-1] +
                     inputs[1].type.broadcastable[2:])
    return Apply(self, inputs, [tensor(dtype, broadcastable)])
def perform(self, node, inp, out):
    """
    Compute the batched dot product with numpy.

    Parameters
    ----------
    node
        The Apply node being executed; its first output supplies the
        dtype of the result.
    inp
        The two input arrays ``x`` and ``y``.
    out
        Output storage; ``out[0][0]`` receives the result array.

    Raises
    ------
    TypeError
        If ``x`` and ``y`` do not have the same size along axis 0.

    """
    x, y = inp
    z, = out

    if x.shape[0] != y.shape[0]:
        raise TypeError(
            'theano.tensor.BatchedDot: inputs [%s] must have the same size'
            ' in axis 0, but have sizes [%s].'
            % (", ".join(map(str, inp)),
               ", ".join([str(i.shape[0]) for i in inp])))

    shape = self.infer_shape(node, [i.shape for i in inp])[0]
    dtype = node.outputs[0].dtype
    result = numpy.empty(shape, dtype=dtype)
    # One ordinary dot product per batch entry.
    for index, (x_item, y_item) in enumerate(zip(x, y)):
        result[index] = numpy.dot(x_item, y_item)
    # Publish the output only once it has been fully computed.
    z[0] = result
def grad(self, inp, grads):
    """
    Return the symbolic gradients with respect to both inputs.

    Each input has 2 or 3 dimensions and axis 0 is the batch axis, so the
    output gradient ``gz`` has 1 dimension when both inputs are 2-D, and
    2 or 3 dimensions otherwise. The four possible rank combinations are
    handled separately.

    Parameters
    ----------
    inp
        The two symbolic inputs ``x`` and ``y``.
    grads
        A one-element sequence holding the gradient ``gz`` with respect
        to the output.

    Returns
    -------
    tuple
        ``(xgrad, ygrad)``, each with the broadcast pattern of the
        corresponding input.

    """
    x, y = inp
    gz, = grads
    xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim

    # grad is a vector, so x is a matrix and y is a matrix
    # (one inner product per batch entry)
    if gdim == 1:
        xgrad = gz.dimshuffle(0, 'x') * y
        ygrad = gz.dimshuffle(0, 'x') * x

    # x is a matrix, y is a tensor3, grad is a matrix
    elif xdim == 2 and ydim == 3:
        xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
        ygrad = x.dimshuffle(0, 1, 'x') * gz.dimshuffle(0, 'x', 1)

    # x is a tensor3, y is a matrix, grad is a matrix
    elif xdim == 3 and ydim == 2:
        xgrad = gz.dimshuffle(0, 1, 'x') * y.dimshuffle(0, 'x', 1)
        ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)

    # x is a tensor3, y is a tensor3, grad is a tensor3
    elif xdim == ydim == 3:
        xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
        ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)

    else:
        # Not reachable for nodes whose inputs have 2 or 3 dimensions;
        # fail loudly rather than with an unbound local.
        raise NotImplementedError(
            'theano.tensor.BatchedDot.grad: unsupported input ranks '
            '(%d, %d)' % (xdim, ydim))

    # If x or y contain broadcastable dimensions but only one of
    # them knows that a matching dimension is broadcastable, the
    # above code doesn't always return the right broadcast pattern.
    # This causes problems down the road. See gh-1461.
    if xgrad.broadcastable != x.broadcastable:
        xgrad = patternbroadcast(xgrad, x.broadcastable)
    if ygrad.broadcastable != y.broadcastable:
        ygrad = patternbroadcast(ygrad, y.broadcastable)

    rval = xgrad, ygrad

    # Sanity check kept from the original: gradients must be floating.
    for elem in rval:
        assert elem.dtype.find('float') != -1

    return rval
def R_op(self, inputs, eval_points):
    """
    Apply the R-operator of the batched dot product.

    Parameters
    ----------
    inputs
        The two symbolic inputs of the op.
    eval_points
        The two evaluation points, one per input; either may be None.

    Returns
    -------
    list
        ``[None]`` if both evaluation points are None; otherwise a
        one-element list holding the sum of the terms contributed by the
        evaluation points that are not None.

    Raises
    ------
    ValueError
        If test values are enabled and available, and an input's test
        value and its evaluation point's test value differ in shape.

    """
    # R_op for a \dot b evaluated at c for a and d for b is
    # simply c \dot b + a \dot d
    assert len(inputs) == 2
    assert len(eval_points) == 2
    if eval_points[0] is None and eval_points[1] is None:
        return [None]

    debugger_available = config.compute_test_value != 'off'

    if debugger_available:
        # Pre-bind every test value so that a missing evaluation point
        # leaves None behind instead of an unbound name.
        iv0 = iv1 = ev0 = ev1 = None
        try:
            iv0 = gof.op.get_test_value(inputs[0])
        except AttributeError:
            gof.op.missing_test_message(
                'first input passed to BatchedDot.R_op has no test value')
            debugger_available = False

        try:
            iv1 = gof.op.get_test_value(inputs[1])
        except AttributeError:
            gof.op.missing_test_message(
                'second input passed to BatchedDot.R_op has no test value')
            debugger_available = False

        if eval_points[0] is not None:
            try:
                ev0 = gof.op.get_test_value(eval_points[0])
            except AttributeError:
                gof.op.missing_test_message(
                    'first eval point passed to BatchedDot.R_op '
                    'has no test value')
                debugger_available = False
        if eval_points[1] is not None:
            try:
                ev1 = gof.op.get_test_value(eval_points[1])
            except AttributeError:
                gof.op.missing_test_message(
                    'second eval point passed to BatchedDot.R_op '
                    'has no test value')
                debugger_available = False

    if debugger_available:
        value_pairs = zip([iv0, iv1], [ev0, ev1])
        for i, (input_value, eval_value) in enumerate(value_pairs):
            if (eval_value is not None and
                    input_value.shape != eval_value.shape):
                raise ValueError(
                    'input %d and eval_point %d to BatchedDot.R_op should '
                    'have the same shape, but their shapes are %s and %s, '
                    'respectively' % (
                        i, i,
                        str(input_value.shape),
                        str(eval_value.shape)))

    # Collect one term per evaluation point that was actually supplied.
    terms = []
    if eval_points[0] is not None:
        terms.append(self(eval_points[0], inputs[1]))
    if eval_points[1] is not None:
        terms.append(self(inputs[0], eval_points[1]))

    if len(terms) == 2:
        return [terms[0] + terms[1]]
    return [terms[0]]
def infer_shape(self, node, shapes):
    """
    Infer the output shape from the two input shapes.

    The result is the shape of the first input without its last axis,
    followed by the axes of the second input from index 2 onwards.

    Raises
    ------
    NotImplementedError
        If any input shape does not have 2 or 3 entries.

    """
    if any(len(candidate) not in (2, 3) for candidate in shapes):
        raise NotImplementedError()
    x_shape, y_shape = shapes
    leading_axes = x_shape[:-1]
    trailing_axes = y_shape[2:]
    return [leading_axes + trailing_axes]
def __str__(self):
    """Return the fixed display name of this op."""
    display_name = "batched_dot"
    return display_name
def batched_dot(a, b):
    """
    Compute the batched dot product of two variables.

    The first axis of ``a`` and ``b`` is the batch axis. Both arguments
    are first converted with ``as_tensor_variable``; the function then
    does one of the following, in this order:

    1. If ``a`` (or else ``b``) has exactly 1 dimension, it returns the
       elementwise product, with the 1-D argument broadcast along all
       non-batch axes of the other one.
    2. If either argument has more than 3 dimensions, it calls
       ``batched_tensordot``, contracting the last axis of ``a`` with
       axis ``max(1, b.ndim - 2)`` of ``b``.
    3. Otherwise it applies the ``BatchedDot`` op to ``a`` and ``b``.

    Parameters
    ----------
    a : tensor
        First operand, e.g. of shape (dim1, dim3, dim2) in the 3-D case.
    b : tensor
        Second operand, e.g. of shape (dim1, dim2, dim4) in the 3-D case.

    Returns
    -------
    tensor
        The batched product, e.g. of shape (dim1, dim3, dim4) in the
        3-D case.

    Notes
    -----
    0-dimensional arguments get no special treatment here; they simply
    fall through the branches listed above.

    Examples
    --------
    >>> first = tensor.tensor3('first')
    >>> second = tensor.tensor3('second')
    >>> result = batched_dot(first, second)

    """
    a, b = as_tensor_variable(a), as_tensor_variable(b)

    if a.ndim == 1:
        # One scalar per batch entry: broadcast it over b's other axes.
        return a.dimshuffle(*([0] + ["x"] * (b.ndim - 1))) * b
    elif b.ndim == 1:
        return a * b.dimshuffle(*([0] + ["x"] * (a.ndim - 1)))
    elif a.ndim > 3 or b.ndim > 3:
        return batched_tensordot(
            a, b, [[a.ndim - 1], [max(1, b.ndim - 2)]])
    else:
        return BatchedDot()(a, b)
def batched_tensordot(x, y, axes=2): def batched_tensordot(x, y, axes=2):
......
...@@ -31,9 +31,9 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -31,9 +31,9 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
horizontal_stack, vertical_stack, argmax, get_vector_length, horizontal_stack, vertical_stack, argmax, get_vector_length,
fscalar, zeros_like, sum, tensor3, vector, add, addbroadcast, fscalar, zeros_like, sum, tensor3, vector, add, addbroadcast,
alloc, as_tensor_variable, tensor_from_scalar, ARange, autocast_float, alloc, as_tensor_variable, tensor_from_scalar, ARange, autocast_float,
clip, constant, default, dot, clip, constant, default, dot, batched_dot,
dmatrix, dscalar, dvector, eq, eye, fill, flatten, inverse_permutation, Flatten, dmatrix, dscalar, dvector, eq, eye, fill, flatten, inverse_permutation,
tensor4, permute_row_elements, fmatrix, fscalars, grad, tensor4, permute_row_elements, Flatten, fmatrix, fscalars, grad,
inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq, inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq,
Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor, Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor,
tensor_copy, tensordot, TensorType, Tri, tri, tril, triu, unbroadcast, tensor_copy, tensordot, TensorType, Tri, tri, tril, triu, unbroadcast,
...@@ -1932,6 +1932,36 @@ DotTester = makeTester(name='DotTester', ...@@ -1932,6 +1932,36 @@ DotTester = makeTester(name='DotTester',
bad_runtime=dict(bad1=(rand(5, 7), rand(5, 7)), bad_runtime=dict(bad1=(rand(5, 7), rand(5, 7)),
bad2=(rand(5, 7), rand(8, 3)))) bad2=(rand(5, 7), rand(8, 3))))
# Generated test case for batched_dot. Each "good" pair is compared with a
# numpy reference; each "bad_runtime" pair is an input combination listed
# as invalid at execution time.
BatchedDotTester = makeTester(
    name='BatchedDotTester',
    op=batched_dot,
    # Reference result: one numpy.dot per pair of batch entries.
    expected=lambda xs, ys: numpy.asarray(
        [numpy.dot(x, y) for x, y in zip(xs, ys)]),
    checks={},
    good={
        'correct1': (rand(3, 5, 7), rand(3, 7, 5)),
        'correct2': (rand(3, 5, 7), rand(3, 7, 9)),
        'correct3': (rand(3, 5, 7), rand(3, 7)),
        'correct4': (rand(3, 5), rand(3, 5, 7)),
        'correct5': (rand(3), rand(3, 5, 7)),
        'correct6': (rand(3, 5), rand(3)),
        'mixed1': (rand(3, 5).astype('float32'),
                   rand(3, 5, 7)),
        'mixed2': (rand(3, 5).astype('float64'),
                   rand(3, 5, 7)),
        'complex1': (randcomplex(3, 5, 7),
                     randcomplex(3, 7)),
        'complex2': (rand(3, 5, 7), randcomplex(3, 7)),
        'complex3': (randcomplex(3, 5, 7), rand(3, 7)),
        'empty1': (numpy.asarray([], dtype=config.floatX),
                   numpy.asarray([], dtype=config.floatX)),
        'empty2': (rand(3, 5, 0), rand(3, 0, 2)),
        'empty3': (rand(3, 0, 5), rand(3, 5, 0)),
    },
    bad_build={},
    bad_runtime={
        # bad1 / bad2: the contracted dimensions differ (7 vs 5, 7 vs 8).
        'bad1': (rand(3, 5, 7), rand(3, 5, 7)),
        'bad2': (rand(3, 5, 7), rand(3, 8, 3)),
        # bad3 / bad4: the batch sizes differ (2 vs 3, 3 vs 2).
        'bad3': (rand(2, 5, 7), rand(3, 7, 9)),
        'bad4': (rand(3, 5, 7), rand(2, 7, 9)),
    })
def _numpy_second(x, y):
    """Return ``y`` as broadcast against ``x`` by numpy.broadcast_arrays."""
    _, broadcast_y = numpy.broadcast_arrays(x, y)
    return broadcast_y
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论