提交 f470707c authored 作者: Frederic's avatar Frederic

Added explaination for a test following code review.

上级 f58e14cc
...@@ -768,6 +768,13 @@ def test_gemm_opt_vector_stuff(): ...@@ -768,6 +768,13 @@ def test_gemm_opt_vector_stuff():
def test_gemm_unrolled(): def test_gemm_unrolled():
"""This test that the gemm optimizer remove the dot22 that was
present in the graph. Otherwise, this add a gemm, but still
compute the dot22.
This was not always the case in the with this the following code.
"""
batch_size = 100 batch_size = 100
rep_size = 40 rep_size = 40
rng = numpy.random.RandomState([1, 2, 3]) rng = numpy.random.RandomState([1, 2, 3])
...@@ -799,7 +806,8 @@ def test_gemm_unrolled(): ...@@ -799,7 +806,8 @@ def test_gemm_unrolled():
if isinstance(node.op, (theano.tensor.Dot, if isinstance(node.op, (theano.tensor.Dot,
theano.tensor.blas.Dot22, theano.tensor.blas.Dot22,
theano.tensor.blas.Gemm))]) theano.tensor.blas.Gemm))])
# Each num_rounds add 3 dot, but one of them is always the same.
# So the final graph should have 1 + 2* num_rounds dot varient op.
assert nb_dot == num_rounds * 2 + 1, nb_dot assert nb_dot == num_rounds * 2 + 1, nb_dot
unrolled_theano() unrolled_theano()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论