提交 bddb97fd authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Thomas Wiecki

Test more cases and make them run

上级 289c48cf
...@@ -465,10 +465,21 @@ class SparseInferShapeTester(utt.InferShapeTester): ...@@ -465,10 +465,21 @@ class SparseInferShapeTester(utt.InferShapeTester):
Dot) Dot)
def test_dot_broadcast(self): def test_dot_broadcast(self):
A = sp.matrix('csr') for x, y in [
b = tensor.vector() (SparseType('csr', 'float32')(), tensor.vector()[:, None]),
bc = sp.dot(A, b[:, None]).broadcastable (SparseType('csr', 'float32')(), tensor.vector()[None, :]),
assert bc == (False, True) (SparseType('csr', 'float32')(), tensor.matrix()),
(tensor.vector()[:, None], SparseType('csr', 'float32')()),
(tensor.vector()[None, :], SparseType('csr', 'float32')()),
(tensor.matrix(), SparseType('csr', 'float32')())]:
sparse_out = theano.dot(x, y)
if isinstance(x, sparse.SparseVariable):
x = tensor.matrix()
if isinstance(y, sparse.SparseVariable):
y = tensor.matrix()
dense_out = tensor.dot(x, y)
assert dense_out.broadcastable == sparse_out.broadcastable
def test_structured_dot(self): def test_structured_dot(self):
x = SparseType('csc', dtype=config.floatX)() x = SparseType('csc', dtype=config.floatX)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论