提交 2acc0720 authored 作者: AdeB's avatar AdeB 提交者: Pascal Lamblin

Fix test_blocksparse

上级 438e7b2b
......@@ -2,8 +2,9 @@ from __future__ import absolute_import, print_function, division
import theano
from theano import tensor
from theano.gof.opt import check_stack_trace
from theano.tensor.nnet.blocksparse import sparse_block_dot, \
sparse_block_gemv_inplace, sparse_block_outer_inplace
from theano.tensor.nnet.blocksparse import (
sparse_block_dot, sparse_block_gemv_inplace, sparse_block_outer_inplace,
sparse_block_gemv, sparse_block_outer)
def test_blocksparse_inplace_gemv_opt():
......@@ -16,12 +17,13 @@ def test_blocksparse_inplace_gemv_opt():
o = sparse_block_dot(W, h, iIdx, b, oIdx)
f = theano.function([W, h, iIdx, b, oIdx], o)
assert check_stack_trace(f, ops_to_check=sparse_block_gemv_inplace)
if theano.config.mode == "FAST_COMPILE":
assert not f.maker.fgraph.toposort()[-1].op.inplace
assert check_stack_trace(f, ops_to_check=[sparse_block_gemv])
else:
assert f.maker.fgraph.toposort()[-1].op.inplace
assert check_stack_trace(f, ops_to_check=[sparse_block_gemv_inplace])
def test_blocksparse_inplace_outer_opt():
......@@ -33,13 +35,12 @@ def test_blocksparse_inplace_outer_opt():
o = sparse_block_dot(W, h, iIdx, b, oIdx)
theano.printing.debugprint(tensor.grad(o.sum(), wrt=W))
f = theano.function([W, h, iIdx, b, oIdx],
[o, tensor.grad(o.sum(), wrt=W)])
assert check_stack_trace(f, ops_to_check=sparse_block_outer_inplace)
if theano.config.mode == "FAST_COMPILE":
assert not f.maker.fgraph.toposort()[-1].op.inplace
assert check_stack_trace(f, ops_to_check=sparse_block_outer)
else:
assert f.maker.fgraph.toposort()[-1].op.inplace
assert check_stack_trace(f, ops_to_check=sparse_block_outer_inplace)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论