提交 1d17699f authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

Dot op only supports matrix/vector args; other ndims must go through tensor.dot…

Dot op only supports matrix/vector args; other ndims must go through tensor.dot to ensure optimizations can be applied
上级 67898e2e
......@@ -6899,19 +6899,25 @@ class Dot(Op):
if len(inputs) != 2:
raise TypeError(
'theanor.tensor.Dot: 2 arguments required, %d given ' %
'theano.tensor.Dot: 2 arguments required, %d given ' %
len(inputs))
if inputs[0].ndim not in (1, 2):
raise TypeError(
'theano.tensor.Dot: input 0 (0-indexed) must have ndim of '
'1 or 2, %d given. Consider calling theano.tensor.dot '
'instead.' % inputs[0].ndim)
if inputs[1].ndim not in (1, 2):
raise TypeError(
'theano.tensor.Dot: input 1 (0-indexed) must have ndim of '
'1 or 2, %d given. Consider calling theano.tensor.dot '
'instead.' % inputs[1].ndim)
i_broadcastables = [input.type.broadcastable for input in inputs]
bx, by = i_broadcastables
if len(bx) == 0: # x is a scalar
bz = by
else:
if len(by) >= 2: # y is a matrix or tensor
bz = bx[:-1] + by[:-2] + by[-1:]
elif len(by) == 1: # y is vector
bz = bx[:-1]
else: # y is a scalar
bz = bx
if len(by) == 2: # y is a matrix
bz = bx[:-1] + by[-1:]
elif len(by) == 1: # y is vector
bz = bx[:-1]
i_dtypes = [input.type.dtype for input in inputs]
outputs = [tensor(scal.upcast(*i_dtypes), bz)]
......@@ -6924,13 +6930,7 @@ class Dot(Op):
# the asarray is here because dot between two vectors
# gives a numpy float object but we need to return a 0d
# ndarray
if x.ndim == 0 or y.ndim == 0:
z[0] = numpy.asarray(x * y)
elif x.ndim > 2 or y.ndim > 2:
axes = [x.ndim - 1, y.ndim - 2]
z[0] = numpy.asarray(numpy.tensordot(x, y, axes))
else:
z[0] = numpy.asarray(numpy.dot(x, y))
z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to
# add that
......@@ -6951,21 +6951,11 @@ class Dot(Op):
gz, = grads
xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
#grad is scalar, so x is scalar or vector and y is same as x
#grad is scalar, so x is vector and y is vector
if gdim == 0:
xgrad = gz * y
ygrad = gz * x
#x is scalar, y is not scalar, grad.shape == y.shape
elif xdim == 0:
xgrad = (gz * y).sum()
ygrad = x * gz
#x is not scalar, y is scalar, grad.shape == x.shape
elif ydim == 0:
xgrad = y * gz
ygrad = (gz * x).sum()
#x is vector, y is matrix, grad is vector
elif xdim == 1 and ydim == 2:
xgrad = dot(gz, y.T)
......@@ -6981,36 +6971,6 @@ class Dot(Op):
xgrad = dot(gz, y.T)
ygrad = dot(x.T, gz)
# x or y is tensor, grad is tensor
#
# the output grad has the same dim as the dot product output, namely
# x.shape[:-1] + y.shape[:-2] + y.shape[-1:]. To get the grad
# wrt x or y, a tensordot is used to sum out non-compatible dims.
#
# for grad wrt x:
# grad is a tensordot of the output grad and y, summing out all but
# the second-to-last dim of y. If y is a vector, no sum is taken.
#
# for grad wrt y:
# grad is a tensordot of the output grad and x, summing out
# all but the last dim of x. If y is not a vector, the tensordot is
# transposed so that its last dim becomes its second-to-last.
else:
gy_axes = range(xdim - 1, gdim)
if ydim != 1:
y_axes = [ax for ax in range(ydim) if ax != ydim - 2]
else:
y_axes = []
xgrad = tensordot(gz, y, [gy_axes, y_axes])
gx_axes = y_axes = range(xdim - 1)
if ydim != 1:
t_dims = range(ydim)
t_dims[-2], t_dims[-1] = t_dims[-1], t_dims[-2]
ygrad = tensordot(gz, x, [gx_axes, y_axes]).transpose(t_dims)
else:
ygrad = tensordot(gz, x, [gx_axes, y_axes])
rval = xgrad, ygrad
for elem in rval:
......@@ -7080,27 +7040,18 @@ class Dot(Op):
xshp, yshp = shapes
x, y = node.inputs
#scalar / scalar
if x.ndim == 0 and y.ndim == 0:
return [()]
#not scalar / scalar
if x.ndim != 0 and y.ndim == 0:
return [xshp]
#scalar / not scalar
if x.ndim == 0 and y.ndim != 0:
return [yshp]
#vector / vector
# vector / vector
if x.ndim == 1 and y.ndim == 1:
return [()]
#tensor / vector
if x.ndim > 1 and y.ndim == 1:
# matrix / vector
if x.ndim == 2 and y.ndim == 1:
return [xshp[:-1]]
#vector / tensor
if x.ndim == 1 and y.ndim > 1:
return [yshp[:-2] + yshp[-1:]]
#tensor / tensor
if x.ndim > 1 and y.ndim > 1:
return [xshp[:-1] + yshp[:-2] + yshp[-1:]]
# vector / matrix
if x.ndim == 1 and y.ndim == 2:
return [yshp[-1:]]
# matrix / matrix
if x.ndim == 2 and y.ndim == 2:
return [xshp[:-1] + yshp[-1:]]
raise NotImplementedError()
def __str__(self):
......@@ -7268,7 +7219,7 @@ def tensordot(a, b, axes = 2):
a_reshaped = a.reshape((a_shape_0, -1), ndim = 2)
b_reshaped = b.reshape((b_shape_0, -1), ndim = 2)
return dot(a_reshaped, b_reshaped).reshape(outshape, outndim)
return _dot(a_reshaped, b_reshaped).reshape(outshape, outndim)
# if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first.
......
......@@ -4262,6 +4262,31 @@ class t_dot(unittest.TestCase):
self.assertTrue(tz.shape == nz.shape)
self.assertTrue(_approx_eq(nz, tz))
def test_Op_dims(self):
# _dot is a Dot op instance
_dot = theano.tensor.basic._dot
d0 = scalar()
d1 = vector()
d2 = matrix()
d3 = tensor3()
self.assertRaises(TypeError, _dot, d0, d0)
self.assertRaises(TypeError, _dot, d0, d1)
self.assertRaises(TypeError, _dot, d0, d2)
self.assertRaises(TypeError, _dot, d0, d3)
self.assertRaises(TypeError, _dot, d1, d0)
_dot(d1, d1)
_dot(d1, d2)
self.assertRaises(TypeError, _dot, d1, d3)
self.assertRaises(TypeError, _dot, d2, d0)
_dot(d2, d1)
_dot(d2, d2)
self.assertRaises(TypeError, _dot, d2, d3)
self.assertRaises(TypeError, _dot, d3, d0)
self.assertRaises(TypeError, _dot, d3, d1)
self.assertRaises(TypeError, _dot, d3, d2)
self.assertRaises(TypeError, _dot, d3, d3)
def test_dot_0d_0d(self):
self.cmp_dot(1.1, 2.2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论