提交 0d0012e5 authored 作者: Frederic Bastien's avatar Frederic Bastien

make test_naacl09.py time with debug mode from 3h40 to 26 minutes.

上级 b6be378f
......@@ -501,7 +501,7 @@ def create_realistic(window_size=3,#7,
model = architecture.make(input_size=input_dimension, input_representation_size=token_representation_size, hidden_representation_size=concatenated_representation_size, output_size=output_vocabsize, lr=lr, seed=seed, noise_level=noise_level, qfilter_relscale=qfilter_relscale, mode=compile_mode)
return model
def test_naacl_model(iters_per_unsup=10, iters_per_sup=10,
def test_naacl_model(iters_per_unsup=3, iters_per_sup=3,
optimizer=None, realistic=False):
print "BUILDING MODEL"
import time
......@@ -513,7 +513,7 @@ def test_naacl_model(iters_per_unsup=10, iters_per_sup=10,
else:
m = create(compile_mode=mode)
print 'BUILD took', time.time() - t
print 'BUILD took %.3fs'%(time.time() - t)
prog_str = []
idx_of_node = {}
for i, node in enumerate(m.pretraining_update.maker.env.toposort()):
......@@ -529,30 +529,34 @@ def test_naacl_model(iters_per_unsup=10, iters_per_sup=10,
rng = N.random.RandomState(unittest_tools.fetch_seed(23904))
inputs = [rng.rand(10,9) for i in 1,2,3]
inputs = [rng.rand(10,m.input_size) for i in 1,2,3]
targets = N.asarray([0,3,4,2,3,4,4,2,1,0])
#print inputs
print 'UNSUPERVISED PHASE'
for i in xrange(10):
t = time.time()
for i in xrange(3):
for j in xrange(iters_per_unsup):
m.pretraining_update(*inputs)
s0, s1 = [str(j) for j in m.pretraining_update(*inputs)]
print 'huh?', i, iters_per_unsup, iters_per_unsup * (i+1), s0, s1
if iters_per_unsup == 10:
assert s0.startswith('0.403044')
assert s1.startswith('0.074898')
if iters_per_unsup == 3:
assert s0.startswith('0.927793')#'0.403044')
assert s1.startswith('0.068035')#'0.074898')
print 'UNSUPERVISED took %.3fs'%(time.time() - t)
print 'FINETUNING GRAPH'
print 'SUPERVISED PHASE COSTS (%s)'%optimizer
for i in xrange(10):
t = time.time()
for i in xrange(3):
for j in xrange(iters_per_unsup):
m.finetuning_update(*(inputs + [targets]))
s0 = str(m.finetuning_update(*(inputs + [targets])))
print iters_per_sup * (i+1), s0
if iters_per_sup == 10:
s0f = float(s0)
assert 15.6510 < s0f and s0f < 15.6512
assert 19.7042 < s0f and s0f < 19.7043
print 'SUPERVISED took %.3fs'%( time.time() - t)
def jtest_main():
from theano import gof
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论