Unverified 提交 ee8c4cf4 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6641 from twiecki/broadcast_sparse_dot

Broadcast sparse dot
......@@ -4008,28 +4008,34 @@ class Dot(gof.op.Op):
"sparse variable as inputs, but the inputs are "
"%s (%s) and %s (%s)." % (x, x.type, y, y.type))
if not x_is_sparse_var:
if x_is_sparse_var:
broadcast_x = (False,) * x.ndim
else:
x = tensor.as_tensor_variable(x)
broadcast_x = x.type.broadcastable
assert y.format in ["csr", "csc"]
if x.ndim not in (1, 2):
raise TypeError(
'theano.sparse.Dot: input 0 (0-indexed) must have ndim of '
'1 or 2, %d given.' % x.ndim)
if not y_is_sparse_var:
if y_is_sparse_var:
broadcast_y = (False,) * y.ndim
else:
y = tensor.as_tensor_variable(y)
broadcast_y = y.type.broadcastable
assert x.format in ["csr", "csc"]
if y.ndim not in (1, 2):
raise TypeError(
'theano.sparse.Dot: input 1 (1-indexed) must have ndim of '
'1 or 2, %d given.' % y.ndim)
if y.ndim == 1 or x.ndim == 1:
bz = (False,)
else:
bz = (False, False)
if len(broadcast_y) == 2:
broadcast_out = broadcast_x[:-1] + broadcast_y[1:]
elif len(broadcast_y) == 1:
broadcast_out = broadcast_x[:-1]
return gof.Apply(self, [x, y], [tensor.tensor(dtype=dtype_out,
broadcastable=bz)])
broadcastable=broadcast_out)])
def perform(self, node, inputs, out):
x, y = inputs
......
......@@ -464,6 +464,23 @@ class SparseInferShapeTester(utt.InferShapeTester):
config.floatX, 3))],
Dot)
def test_dot_broadcast(self):
for x, y in [
(SparseType('csr', 'float32')(), tensor.vector()[:, None]),
(SparseType('csr', 'float32')(), tensor.vector()[None, :]),
(SparseType('csr', 'float32')(), tensor.matrix()),
(tensor.vector()[:, None], SparseType('csr', 'float32')()),
(tensor.vector()[None, :], SparseType('csr', 'float32')()),
(tensor.matrix(), SparseType('csr', 'float32')())]:
sparse_out = theano.dot(x, y)
if isinstance(x, sparse.SparseVariable):
x = tensor.matrix()
if isinstance(y, sparse.SparseVariable):
y = tensor.matrix()
dense_out = tensor.dot(x, y)
assert dense_out.broadcastable == sparse_out.broadcastable
def test_structured_dot(self):
x = SparseType('csc', dtype=config.floatX)()
y = SparseType('csc', dtype=config.floatX)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论