提交 a677494a authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fix test

上级 52b6a213
......@@ -8,7 +8,8 @@ from theano.sandbox import multinomial
from theano.compile.mode import get_default_mode, predefined_linkers
import theano.sandbox.cuda as cuda
import theano.tests.unittest_tools as utt
import cPickle
import six.moves.cPickle as pickle
import os
def get_mode(gpu):
mode = get_default_mode()
......@@ -77,9 +78,10 @@ def test_n_samples_compatibility():
pvals = T.exp(X)
pvals = pvals / pvals.sum(axis=1, keepdims=True)
samples = th_rng.multinomial(pvals=pvals)
cPickle.dump([X, samples], open("multinomial_test_graph.pkl", "w"))
pickle.dump([X, samples], open("multinomial_test_graph.pkl", "w"))
"""
X, samples = cPickle.load(open("multinomial_test_graph.pkl"))
folder = os.path.dirname(os.path.abspath(__file__))
X, samples = pickle.load(open(os.path.join(folder, "multinomial_test_graph.pkl")))
f = theano.function([X], samples)
res = f(numpy.random.randn(20,10))
assert numpy.all(res.sum(axis=1) == 1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论