提交 8741d36a authored 作者: Frederic's avatar Frederic

Add sparse.dot infer_shape test.

上级 d0d41c3b
...@@ -598,7 +598,18 @@ class DotTests(unittest.TestCase): ...@@ -598,7 +598,18 @@ class DotTests(unittest.TestCase):
vx = getattr(self,'x_'+x_f).astype(d1) vx = getattr(self,'x_'+x_f).astype(d1)
vy = getattr(self,'y_'+y_f).astype(d2) vy = getattr(self,'y_'+y_f).astype(d2)
assert abs(f_a(vx, vy) - f_b(vx, vy)).max() < 1e-4 assert _allclose(f_a(vx, vy), f_b(vx, vy).toarray())
# Test infer_shape
f_a = theano.function([x, y], theano.sparse.dot(x, y).shape)
f_b = lambda x, y: (x * y).shape
assert numpy.all(f_a(vx, vy) == f_b(vx, vy))
topo = f_a.maker.env.toposort()
if theano.config.mode!='FAST_COMPILE':
nb = 0
else:
nb = 1
assert sum([isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo]) == nb
class UsmmTests(unittest.TestCase): class UsmmTests(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论