提交 289c48cf authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Thomas Wiecki

Fix broadcasting in sparse dot

上级 8430dd67
......@@ -4008,28 +4008,34 @@ class Dot(gof.op.Op):
"sparse variable as inputs, but the inputs are "
"%s (%s) and %s (%s)." % (x, x.type, y, y.type))
if not x_is_sparse_var:
if x_is_sparse_var:
broadcast_x = (False,) * x.ndim
else:
x = tensor.as_tensor_variable(x)
broadcast_x = x.type.broadcastable
assert y.format in ["csr", "csc"]
if x.ndim not in (1, 2):
raise TypeError(
'theano.sparse.Dot: input 0 (0-indexed) must have ndim of '
'1 or 2, %d given.' % x.ndim)
if not y_is_sparse_var:
if y_is_sparse_var:
broadcast_y = (False,) * y.ndim
else:
y = tensor.as_tensor_variable(y)
broadcast_y = y.type.broadcastable
assert x.format in ["csr", "csc"]
if y.ndim not in (1, 2):
raise TypeError(
'theano.sparse.Dot: input 1 (1-indexed) must have ndim of '
'1 or 2, %d given.' % y.ndim)
if y.ndim == 1 or x.ndim == 1:
bz = (False,)
else:
bz = (False, False)
if len(broadcast_y) == 2:
broadcast_out = broadcast_x[:-1] + broadcast_y[1:]
elif len(broadcast_y) == 1:
broadcast_out = broadcast_x[:-1]
return gof.Apply(self, [x, y], [tensor.tensor(dtype=dtype_out,
broadcastable=bz)])
broadcastable=broadcast_out)])
def perform(self, node, inputs, out):
x, y = inputs
......
......@@ -464,6 +464,12 @@ class SparseInferShapeTester(utt.InferShapeTester):
config.floatX, 3))],
Dot)
def test_dot_broadcast(self):
A = sp.matrix('csr')
b = tensor.vector()
bc = sp.dot(A, b[:, None]).broadcastable
assert bc == (False, True)
def test_structured_dot(self):
x = SparseType('csc', dtype=config.floatX)()
y = SparseType('csc', dtype=config.floatX)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论