提交 f470707c authored 作者: Frederic's avatar Frederic

Added explaination for a test following code review.

上级 f58e14cc
......@@ -768,6 +768,13 @@ def test_gemm_opt_vector_stuff():
def test_gemm_unrolled():
"""This test that the gemm optimizer remove the dot22 that was
present in the graph. Otherwise, this add a gemm, but still
compute the dot22.
This was not always the case in the with this the following code.
"""
batch_size = 100
rep_size = 40
rng = numpy.random.RandomState([1, 2, 3])
......@@ -799,7 +806,8 @@ def test_gemm_unrolled():
if isinstance(node.op, (theano.tensor.Dot,
theano.tensor.blas.Dot22,
theano.tensor.blas.Gemm))])
# Each num_rounds add 3 dot, but one of them is always the same.
# So the final graph should have 1 + 2* num_rounds dot varient op.
assert nb_dot == num_rounds * 2 + 1, nb_dot
unrolled_theano()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论