提交 247618cc authored 作者: David Warde-Farley's avatar David Warde-Farley

Tests for StructuredDot and gradients.

上级 0e226ff7
...@@ -363,6 +363,22 @@ class test_structureddot(unittest.TestCase): ...@@ -363,6 +363,22 @@ class test_structureddot(unittest.TestCase):
utt.verify_grad(buildgraph, utt.verify_grad(buildgraph,
[spmat.data, mat]) [spmat.data, mat])
def test_infer_shape_csr_csc_grad(self):
for sparsetype in ('csr', 'csc'):
a = SparseType(sparsetype, dtype=config.floatX)()
b = SparseType(sparsetype, dtype=config.floatX)()
grads = tensor.grad(dense_from_sparse(structured_dot(a, b)).sum(),
[a, b])
f = theano.function([a, b], [g.shape for g in grads])
topo = f.maker.env.toposort()
assert not any(isinstance(t, self.__class__) for t in topo)
call = getattr(sp, sparsetype + '_matrix')
x = call(random_lil((500, 300), config.floatX, 10))
y = call(random_lil((300, 400), config.floatX, 5))
out1, out2 = f(x, y)
assert numpy.all(out1 == x.shape)
assert numpy.all(out2 == y.shape)
def test_upcast(self): def test_upcast(self):
typenames = ('float32', 'int64', 'int8', 'int32', typenames = ('float32', 'int64', 'int8', 'int32',
...@@ -553,6 +569,16 @@ class test_structureddot(unittest.TestCase): ...@@ -553,6 +569,16 @@ class test_structureddot(unittest.TestCase):
self.assertFalse(theano_time > overhead_rtol * scipy_time + self.assertFalse(theano_time > overhead_rtol * scipy_time +
overhead_tol) overhead_tol)
def test_infer_shape(self):
a = SparseType('csc', dtype=config.floatX)()
b = SparseType('csc', dtype=config.floatX)()
f = theano.function([a, b], structured_dot(a, b).shape)
topo = f.maker.env.toposort()
assert not any(isinstance(t, self.__class__) for t in topo)
x = sp.csc_matrix((4, 5), dtype=config.floatX)
y = sp.csc_matrix((5, 3), dtype=config.floatX)
assert numpy.all(f(x, y) == numpy.array((4, 3)))
class DotTests(unittest.TestCase): class DotTests(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论