提交 2acc0720 authored 作者: AdeB's avatar AdeB 提交者: Pascal Lamblin

Fix test_blocksparse

上级 438e7b2b
...@@ -2,8 +2,9 @@ from __future__ import absolute_import, print_function, division ...@@ -2,8 +2,9 @@ from __future__ import absolute_import, print_function, division
import theano import theano
from theano import tensor from theano import tensor
from theano.gof.opt import check_stack_trace from theano.gof.opt import check_stack_trace
from theano.tensor.nnet.blocksparse import sparse_block_dot, \ from theano.tensor.nnet.blocksparse import (
sparse_block_gemv_inplace, sparse_block_outer_inplace sparse_block_dot, sparse_block_gemv_inplace, sparse_block_outer_inplace,
sparse_block_gemv, sparse_block_outer)
def test_blocksparse_inplace_gemv_opt(): def test_blocksparse_inplace_gemv_opt():
...@@ -16,12 +17,13 @@ def test_blocksparse_inplace_gemv_opt(): ...@@ -16,12 +17,13 @@ def test_blocksparse_inplace_gemv_opt():
o = sparse_block_dot(W, h, iIdx, b, oIdx) o = sparse_block_dot(W, h, iIdx, b, oIdx)
f = theano.function([W, h, iIdx, b, oIdx], o) f = theano.function([W, h, iIdx, b, oIdx], o)
assert check_stack_trace(f, ops_to_check=sparse_block_gemv_inplace)
if theano.config.mode == "FAST_COMPILE": if theano.config.mode == "FAST_COMPILE":
assert not f.maker.fgraph.toposort()[-1].op.inplace assert not f.maker.fgraph.toposort()[-1].op.inplace
assert check_stack_trace(f, ops_to_check=[sparse_block_gemv])
else: else:
assert f.maker.fgraph.toposort()[-1].op.inplace assert f.maker.fgraph.toposort()[-1].op.inplace
assert check_stack_trace(f, ops_to_check=[sparse_block_gemv_inplace])
def test_blocksparse_inplace_outer_opt(): def test_blocksparse_inplace_outer_opt():
...@@ -33,13 +35,12 @@ def test_blocksparse_inplace_outer_opt(): ...@@ -33,13 +35,12 @@ def test_blocksparse_inplace_outer_opt():
o = sparse_block_dot(W, h, iIdx, b, oIdx) o = sparse_block_dot(W, h, iIdx, b, oIdx)
theano.printing.debugprint(tensor.grad(o.sum(), wrt=W))
f = theano.function([W, h, iIdx, b, oIdx], f = theano.function([W, h, iIdx, b, oIdx],
[o, tensor.grad(o.sum(), wrt=W)]) [o, tensor.grad(o.sum(), wrt=W)])
assert check_stack_trace(f, ops_to_check=sparse_block_outer_inplace)
if theano.config.mode == "FAST_COMPILE": if theano.config.mode == "FAST_COMPILE":
assert not f.maker.fgraph.toposort()[-1].op.inplace assert not f.maker.fgraph.toposort()[-1].op.inplace
assert check_stack_trace(f, ops_to_check=sparse_block_outer)
else: else:
assert f.maker.fgraph.toposort()[-1].op.inplace assert f.maker.fgraph.toposort()[-1].op.inplace
assert check_stack_trace(f, ops_to_check=sparse_block_outer_inplace)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论