Commit b094c0bb authored by --global

Merge test_log_softmax and test_log_softmax_opt

上级 b09557a6
...@@ -437,7 +437,9 @@ def test_pooling_opt(): ...@@ -437,7 +437,9 @@ def test_pooling_opt():
def test_log_softmax(): def test_log_softmax():
if not cuda.dnn.dnn_available(): # This is a test for an optimization that depends on CuDNN v3 or
# more recent. Don't test if the CuDNN version is too old.
if not cuda.dnn.dnn_available() or cuda.dnn.version() < (3000, 3000):
raise SkipTest(cuda.dnn.dnn_available.msg) raise SkipTest(cuda.dnn.dnn_available.msg)
x = T.ftensor4() x = T.ftensor4()
...@@ -446,6 +448,12 @@ def test_log_softmax(): ...@@ -446,6 +448,12 @@ def test_log_softmax():
f = theano.function([x], log_out, mode=mode_with_gpu) f = theano.function([x], log_out, mode=mode_with_gpu)
# Ensure that the optimization has been applied
dnn_softmax_nodes = [n for n in f.maker.fgraph.toposort() if
isinstance(n.op, cuda.dnn.GpuDnnSoftmax)]
assert len(dnn_softmax_nodes) == 1
assert dnn_softmax_nodes[0].op.algo == "log"
# Ensure that the output of the function is valid # Ensure that the output of the function is valid
input_shapes = [(3, 4, 5, 6), input_shapes = [(3, 4, 5, 6),
(1025, 2, 3, 4), (1025, 2, 3, 4),
...@@ -467,26 +475,6 @@ def test_log_softmax(): ...@@ -467,26 +475,6 @@ def test_log_softmax():
utt.assert_allclose(out, expected_out) utt.assert_allclose(out, expected_out)
def test_log_softmax_opt():
    """Verify that log(GpuDnnSoftmax('accurate')(x)) is rewritten by the
    optimizer into a single GpuDnnSoftmax node running in "log" mode.

    Requires CuDNN v3 or newer; skipped otherwise.
    """
    # The optimization under test only exists for CuDNN v3+, so skip the
    # test when CuDNN is unavailable or too old.
    if not cuda.dnn.dnn_available() or cuda.dnn.version() < (3000, 3000):
        raise SkipTest(cuda.dnn.dnn_available.msg)

    inp = T.ftensor4()
    sm = dnn.GpuDnnSoftmax('bc01', 'accurate', 'channel')(inp)
    logged = T.log(T.as_tensor_variable(sm))
    fn = theano.function([inp], logged, mode=mode_with_gpu)

    # After compilation the graph should contain exactly one softmax
    # node, and it should have been switched to the "log" algorithm.
    softmax_apply_nodes = [
        node for node in fn.maker.fgraph.toposort()
        if isinstance(node.op, cuda.dnn.GpuDnnSoftmax)
    ]
    assert len(softmax_apply_nodes) == 1
    assert softmax_apply_nodes[0].op.algo == "log"
def test_dnn_tag(): def test_dnn_tag():
""" """
Test that if cudnn isn't avail we crash and that if it is avail, we use it. Test that if cudnn isn't avail we crash and that if it is avail, we use it.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论