Commit 86eea375, authored by Jeremiah Lowin

Update dot op to support n-dimensional variables (numpy semantics)

- call numpy.dot for inputs with ndim == 0 or ndim > 2
- add gradient calculations for the various cases
- add infer_shape for the various cases
- update the docstring
上级 9555a94f
...@@ -6867,7 +6867,14 @@ def take(a, indices, axis=None, mode='raise'): ...@@ -6867,7 +6867,14 @@ def take(a, indices, axis=None, mode='raise'):
class Dot(Op): class Dot(Op):
"""Compute matrix-matrix, matrix-vector products and vector inner-products. """
Computes the dot product of two variables. For two matrices, this is
equivalent to matrix multiplication. For two vectors, this is the inner
product. When one variable is a scalar, it is like elementwise
multiplication. For N dimensions, it is a sum product over the last axis
of the first array and the second-to-last axis of the second array:
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
:note: matrix-matrix products are sometimes optimized to Dot22 ops :note: matrix-matrix products are sometimes optimized to Dot22 ops
(see tensor.blas) (see tensor.blas)
...@@ -6890,51 +6897,21 @@ class Dot(Op): ...@@ -6890,51 +6897,21 @@ class Dot(Op):
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = map(as_tensor_variable, inputs) inputs = map(as_tensor_variable, inputs)
numpy_semantics = 0 if len(inputs) != 2:
if numpy_semantics: raise TypeError(
# numpy defines dot for tensor pairs with any rank 'theanor.tensor.Dot: 2 arguments required, %d given ' %
if len(inputs) != 2: len(inputs))
raise TypeError( i_broadcastables = [input.type.broadcastable for input in inputs]
"Wrong number of inputs for %s (got %i, expected 2)" % bx, by = i_broadcastables
self) if len(bx) == 0: # x is a scalar
i_broadcastables = [input.type.broadcastable for input in inputs] bz = by
bx, by = i_broadcastables
if len(bx) == 0: # x is a scalar
bz = by
else:
if len(by) >= 2: # y is a matrix or tensor
bz = bx[:-1] + by[:-2] + by[-1:]
elif len(by) == 1: # y is vector
bz = bx[:-1]
else: # y is a scalar
bz = bx
else: else:
if len(inputs) != 2: if len(by) >= 2: # y is a matrix or tensor
raise TypeError( bz = bx[:-1] + by[:-2] + by[-1:]
'theanor.tensor.Dot: 2 arguments required, %d given ' % elif len(by) == 1: # y is vector
len(inputs)) bz = bx[:-1]
else: # y is a scalar
x, y = inputs bz = bx
nx = x.type.ndim
ny = y.type.ndim
if nx not in (1, 2):
raise TypeError(
('dot supports matrix and vector args: email theano-dev '
'about enabling numpy dot semantics if you want them'), x)
if ny not in (1, 2):
raise TypeError(
('dot supports matrix and vector args: email theano-dev '
'about enabling numpy dot semantics if you want them'), y)
if nx == 2 and ny == 2:
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
elif nx == 1 and ny == 2:
bz = [y.type.broadcastable[1]]
elif nx == 2 and ny == 1:
bz = [x.type.broadcastable[0]]
else:
bz = []
i_dtypes = [input.type.dtype for input in inputs] i_dtypes = [input.type.dtype for input in inputs]
outputs = [tensor(scal.upcast(*i_dtypes), bz)] outputs = [tensor(scal.upcast(*i_dtypes), bz)]
...@@ -6966,14 +6943,50 @@ class Dot(Op): ...@@ -6966,14 +6943,50 @@ class Dot(Op):
x, y = inp x, y = inp
gz, = grads gz, = grads
xdim, ydim = x.type.ndim, y.type.ndim
#grad is scalar
if gz.type.ndim == 0: if gz.type.ndim == 0:
rval = gz * y, gz * x xgrad = gz * y
elif x.type.ndim == 1 and y.type.ndim > 1: ygrad = gz * x
rval = dot(gz, y.T), outer(x.T, gz) #x is scalar
elif x.type.ndim > 1 and y.type.ndim == 1: elif xdim == 0:
rval = outer(gz, y.T), dot(x.T, gz) xgrad = (gz * y).sum()
ygrad = x * gz
#y is scalar
elif ydim == 0:
xgrad = y * gz
ygrad = (gz * x).sum()
#x is vector, y is matrix
elif xdim == 1 and ydim == 2:
xgrad = dot(gz, y.T)
ygrad = outer(x.T, gz)
#x is matrix, y is vector
elif xdim == 2 and ydim == 1:
xgrad = outer(gz, y.T)
ygrad = dot(x.T, gz)
#x is matrix, y is matrix
elif xdim == ydim == 2:
xgrad = dot(gz, y.T)
ygrad = dot(x.T, gz)
#x is tensor, y is vector (corner case)
elif xdim > 2 and ydim == 1:
xgrad = tensordot(y, gz, 0).transpose(range(xdim)[1:] + [0])
ygrad = tensordot(x, gz, [range(xdim - 1)] * 2)
#x or y is tensor
else: else:
rval = dot(gz, y.T), dot(x.T, gz) sum0, sum1 = range(xdim), range(xdim - 1)
sum0.pop(-1)
dims = range(ydim)
dims[-1:-1] = [dims.pop(0)]
ygrad = tensordot(x, gz, [sum0, sum1]).transpose(dims)
sum0, sum1 = range(ydim), range(xdim - 1, xdim + ydim - 2)
sum0.pop(-2)
dims = range(xdim)[1:] + [0]
xgrad = tensordot(y, gz, [sum0, sum1]).transpose(dims)
rval = xgrad, ygrad
for elem in rval: for elem in rval:
assert elem.dtype.find('float') != -1 assert elem.dtype.find('float') != -1
...@@ -7041,14 +7054,28 @@ class Dot(Op): ...@@ -7041,14 +7054,28 @@ class Dot(Op):
def infer_shape(self, node, shapes):
    """Infer the output shape of ``dot(x, y)`` under numpy semantics.

    :param node: the Apply node; the ranks (``ndim``) of the two
        operands are read from ``node.inputs``.
    :param shapes: list of the two input shape tuples ``(xshp, yshp)``.
    :returns: a one-element list holding the output shape tuple.
    :raises NotImplementedError: if no rank combination matches
        (should not happen for inputs accepted by ``make_node``).
    """
    xshp, yshp = shapes
    x, y = node.inputs
    # scalar . scalar -> scalar
    if x.ndim == 0 and y.ndim == 0:
        return [()]
    # non-scalar . scalar -> elementwise scaling keeps x's shape
    if x.ndim != 0 and y.ndim == 0:
        return [xshp]
    # scalar . non-scalar -> elementwise scaling keeps y's shape
    if x.ndim == 0 and y.ndim != 0:
        return [yshp]
    # vector . vector -> inner product, a scalar
    if x.ndim == 1 and y.ndim == 1:
        return [()]
    # tensor . vector -> sum over x's last axis drops that axis
    if x.ndim > 1 and y.ndim == 1:
        return [xshp[:-1]]
    # vector . tensor -> sum over y's second-to-last axis
    if x.ndim == 1 and y.ndim > 1:
        return [yshp[:-2] + yshp[-1:]]
    # tensor . tensor (numpy semantics):
    # out.shape = x.shape[:-1] + y.shape[:-2] + y.shape[-1:]
    if x.ndim > 1 and y.ndim > 1:
        return [xshp[:-1] + yshp[:-2] + yshp[-1:]]
    raise NotImplementedError()
def __str__(self): def __str__(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论