提交 a677494a authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fix test

上级 52b6a213
...@@ -8,7 +8,8 @@ from theano.sandbox import multinomial ...@@ -8,7 +8,8 @@ from theano.sandbox import multinomial
from theano.compile.mode import get_default_mode, predefined_linkers from theano.compile.mode import get_default_mode, predefined_linkers
import theano.sandbox.cuda as cuda import theano.sandbox.cuda as cuda
import theano.tests.unittest_tools as utt import theano.tests.unittest_tools as utt
import cPickle import six.moves.cPickle as pickle
import os
def get_mode(gpu): def get_mode(gpu):
mode = get_default_mode() mode = get_default_mode()
...@@ -77,9 +78,10 @@ def test_n_samples_compatibility(): ...@@ -77,9 +78,10 @@ def test_n_samples_compatibility():
pvals = T.exp(X) pvals = T.exp(X)
pvals = pvals / pvals.sum(axis=1, keepdims=True) pvals = pvals / pvals.sum(axis=1, keepdims=True)
samples = th_rng.multinomial(pvals=pvals) samples = th_rng.multinomial(pvals=pvals)
cPickle.dump([X, samples], open("multinomial_test_graph.pkl", "w")) pickle.dump([X, samples], open("multinomial_test_graph.pkl", "w"))
""" """
X, samples = cPickle.load(open("multinomial_test_graph.pkl")) folder = os.path.dirname(os.path.abspath(__file__))
X, samples = pickle.load(open(os.path.join(folder, "multinomial_test_graph.pkl")))
f = theano.function([X], samples) f = theano.function([X], samples)
res = f(numpy.random.randn(20,10)) res = f(numpy.random.randn(20,10))
assert numpy.all(res.sum(axis=1) == 1) assert numpy.all(res.sum(axis=1) == 1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论