提交 1d17699f authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

Dot op only supports matrix/vector args; other ndims must go through tensor.dot to ensure optimizations can be applied

Dot op only supports matrix/vector args; other ndims must go through tensor.dot to ensure optimizations can be applied
上级 67898e2e
...@@ -6899,19 +6899,25 @@ class Dot(Op): ...@@ -6899,19 +6899,25 @@ class Dot(Op):
if len(inputs) != 2: if len(inputs) != 2:
raise TypeError( raise TypeError(
'theanor.tensor.Dot: 2 arguments required, %d given ' % 'theano.tensor.Dot: 2 arguments required, %d given ' %
len(inputs)) len(inputs))
if inputs[0].ndim not in (1, 2):
raise TypeError(
'theano.tensor.Dot: input 0 (0-indexed) must have ndim of '
'1 or 2, %d given. Consider calling theano.tensor.dot '
'instead.' % inputs[0].ndim)
if inputs[1].ndim not in (1, 2):
raise TypeError(
'theano.tensor.Dot: input 1 (0-indexed) must have ndim of '
'1 or 2, %d given. Consider calling theano.tensor.dot '
'instead.' % inputs[1].ndim)
i_broadcastables = [input.type.broadcastable for input in inputs] i_broadcastables = [input.type.broadcastable for input in inputs]
bx, by = i_broadcastables bx, by = i_broadcastables
if len(bx) == 0: # x is a scalar if len(by) == 2: # y is a matrix
bz = by bz = bx[:-1] + by[-1:]
else:
if len(by) >= 2: # y is a matrix or tensor
bz = bx[:-1] + by[:-2] + by[-1:]
elif len(by) == 1: # y is vector elif len(by) == 1: # y is vector
bz = bx[:-1] bz = bx[:-1]
else: # y is a scalar
bz = bx
i_dtypes = [input.type.dtype for input in inputs] i_dtypes = [input.type.dtype for input in inputs]
outputs = [tensor(scal.upcast(*i_dtypes), bz)] outputs = [tensor(scal.upcast(*i_dtypes), bz)]
...@@ -6924,12 +6930,6 @@ class Dot(Op): ...@@ -6924,12 +6930,6 @@ class Dot(Op):
# the asarray is here because dot between two vectors # the asarray is here because dot between two vectors
# gives a numpy float object but we need to return a 0d # gives a numpy float object but we need to return a 0d
# ndarray # ndarray
if x.ndim == 0 or y.ndim == 0:
z[0] = numpy.asarray(x * y)
elif x.ndim > 2 or y.ndim > 2:
axes = [x.ndim - 1, y.ndim - 2]
z[0] = numpy.asarray(numpy.tensordot(x, y, axes))
else:
z[0] = numpy.asarray(numpy.dot(x, y)) z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e: except ValueError, e:
# The error raised by numpy has no shape information, we mean to # The error raised by numpy has no shape information, we mean to
...@@ -6951,21 +6951,11 @@ class Dot(Op): ...@@ -6951,21 +6951,11 @@ class Dot(Op):
gz, = grads gz, = grads
xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
#grad is scalar, so x is scalar or vector and y is same as x #grad is scalar, so x is vector and y is vector
if gdim == 0: if gdim == 0:
xgrad = gz * y xgrad = gz * y
ygrad = gz * x ygrad = gz * x
#x is scalar, y is not scalar, grad.shape == y.shape
elif xdim == 0:
xgrad = (gz * y).sum()
ygrad = x * gz
#x is not scalar, y is scalar, grad.shape == x.shape
elif ydim == 0:
xgrad = y * gz
ygrad = (gz * x).sum()
#x is vector, y is matrix, grad is vector #x is vector, y is matrix, grad is vector
elif xdim == 1 and ydim == 2: elif xdim == 1 and ydim == 2:
xgrad = dot(gz, y.T) xgrad = dot(gz, y.T)
...@@ -6981,36 +6971,6 @@ class Dot(Op): ...@@ -6981,36 +6971,6 @@ class Dot(Op):
xgrad = dot(gz, y.T) xgrad = dot(gz, y.T)
ygrad = dot(x.T, gz) ygrad = dot(x.T, gz)
# x or y is tensor, grad is tensor
#
# the output grad has the same dim as the dot product output, namely
# x.shape[:-1] + y.shape[:-2] + y.shape[-1:]. To get the grad
# wrt x or y, a tensordot is used to sum out non-compatible dims.
#
# for grad wrt x:
# grad is a tensordot of the output grad and y, summing out all but
# the second-to-last dim of y. If y is a vector, no sum is taken.
#
# for grad wrt y:
# grad is a tensordot of the output grad and x, summing out
# all but the last dim of x. If y is not a vector, the tensordot is
# transposed so that its last dim becomes its second-to-last.
else:
gy_axes = range(xdim - 1, gdim)
if ydim != 1:
y_axes = [ax for ax in range(ydim) if ax != ydim - 2]
else:
y_axes = []
xgrad = tensordot(gz, y, [gy_axes, y_axes])
gx_axes = y_axes = range(xdim - 1)
if ydim != 1:
t_dims = range(ydim)
t_dims[-2], t_dims[-1] = t_dims[-1], t_dims[-2]
ygrad = tensordot(gz, x, [gx_axes, y_axes]).transpose(t_dims)
else:
ygrad = tensordot(gz, x, [gx_axes, y_axes])
rval = xgrad, ygrad rval = xgrad, ygrad
for elem in rval: for elem in rval:
...@@ -7080,27 +7040,18 @@ class Dot(Op): ...@@ -7080,27 +7040,18 @@ class Dot(Op):
xshp, yshp = shapes xshp, yshp = shapes
x, y = node.inputs x, y = node.inputs
#scalar / scalar # vector / vector
if x.ndim == 0 and y.ndim == 0:
return [()]
#not scalar / scalar
if x.ndim != 0 and y.ndim == 0:
return [xshp]
#scalar / not scalar
if x.ndim == 0 and y.ndim != 0:
return [yshp]
#vector / vector
if x.ndim == 1 and y.ndim == 1: if x.ndim == 1 and y.ndim == 1:
return [()] return [()]
#tensor / vector # matrix / vector
if x.ndim > 1 and y.ndim == 1: if x.ndim == 2 and y.ndim == 1:
return [xshp[:-1]] return [xshp[:-1]]
#vector / tensor # vector / matrix
if x.ndim == 1 and y.ndim > 1: if x.ndim == 1 and y.ndim == 2:
return [yshp[:-2] + yshp[-1:]] return [yshp[-1:]]
#tensor / tensor # matrix / matrix
if x.ndim > 1 and y.ndim > 1: if x.ndim == 2 and y.ndim == 2:
return [xshp[:-1] + yshp[:-2] + yshp[-1:]] return [xshp[:-1] + yshp[-1:]]
raise NotImplementedError() raise NotImplementedError()
def __str__(self): def __str__(self):
...@@ -7268,7 +7219,7 @@ def tensordot(a, b, axes = 2): ...@@ -7268,7 +7219,7 @@ def tensordot(a, b, axes = 2):
a_reshaped = a.reshape((a_shape_0, -1), ndim = 2) a_reshaped = a.reshape((a_shape_0, -1), ndim = 2)
b_reshaped = b.reshape((b_shape_0, -1), ndim = 2) b_reshaped = b.reshape((b_shape_0, -1), ndim = 2)
return dot(a_reshaped, b_reshaped).reshape(outshape, outndim) return _dot(a_reshaped, b_reshaped).reshape(outshape, outndim)
# if 'axes' is a list, transpose a and b such that the summed axes of a # if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first. # are last and the summed axes of b are first.
......
...@@ -4262,6 +4262,31 @@ class t_dot(unittest.TestCase): ...@@ -4262,6 +4262,31 @@ class t_dot(unittest.TestCase):
self.assertTrue(tz.shape == nz.shape) self.assertTrue(tz.shape == nz.shape)
self.assertTrue(_approx_eq(nz, tz)) self.assertTrue(_approx_eq(nz, tz))
def test_Op_dims(self):
# _dot is a Dot op instance
_dot = theano.tensor.basic._dot
d0 = scalar()
d1 = vector()
d2 = matrix()
d3 = tensor3()
self.assertRaises(TypeError, _dot, d0, d0)
self.assertRaises(TypeError, _dot, d0, d1)
self.assertRaises(TypeError, _dot, d0, d2)
self.assertRaises(TypeError, _dot, d0, d3)
self.assertRaises(TypeError, _dot, d1, d0)
_dot(d1, d1)
_dot(d1, d2)
self.assertRaises(TypeError, _dot, d1, d3)
self.assertRaises(TypeError, _dot, d2, d0)
_dot(d2, d1)
_dot(d2, d2)
self.assertRaises(TypeError, _dot, d2, d3)
self.assertRaises(TypeError, _dot, d3, d0)
self.assertRaises(TypeError, _dot, d3, d1)
self.assertRaises(TypeError, _dot, d3, d2)
self.assertRaises(TypeError, _dot, d3, d3)
def test_dot_0d_0d(self): def test_dot_0d_0d(self):
self.cmp_dot(1.1, 2.2) self.cmp_dot(1.1, 2.2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论