提交 579b1594 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix ignore_error=True by using the requested batch_size and print the time…

fix ignore_error=True by using the requested batch_size and print the time executed and print the result at each iter.
上级 d2264af6
...@@ -283,24 +283,29 @@ def run_conv_nnet2_classif(shared_fn, isize, ksize, n_batch=60, n_iter=25): ...@@ -283,24 +283,29 @@ def run_conv_nnet2_classif(shared_fn, isize, ksize, n_batch=60, n_iter=25):
yval = numpy.asarray(numpy.random.rand(n_batch,n_out), dtype='float32') yval = numpy.asarray(numpy.random.rand(n_batch,n_out), dtype='float32')
lr = numpy.asarray(0.01, dtype='float32') lr = numpy.asarray(0.01, dtype='float32')
rvals=numpy.zeros(n_iter)
t0 = time.time()
for i in xrange(n_iter): for i in xrange(n_iter):
rval = train(xval, yval, lr) rvals[i] = train(xval, yval, lr)[0]
if i % 10 == 0: t1 = time.time()
print 'rval', rval
print_mode(mode) print_mode(mode)
return rval return rvals, t1-t0
def run_test_conv_nnet2_classif(seed, isize, ksize, bsize, ignore_error=False): def run_test_conv_nnet2_classif(seed, isize, ksize, bsize, ignore_error=False):
if ignore_error: if ignore_error:
numpy.random.seed(seed) numpy.random.seed(seed)
rval_gpu = run_conv_nnet2_classif(tcn.shared_constructor, isize, ksize) rval_gpu, t = run_conv_nnet2_classif(tcn.shared_constructor, isize, ksize, bsize)
return return
numpy.random.seed(seed) numpy.random.seed(seed)
rval_cpu = run_conv_nnet2_classif(shared, isize, ksize, bsize) rval_cpu, tc = run_conv_nnet2_classif(shared, isize, ksize, bsize)
numpy.random.seed(seed) numpy.random.seed(seed)
rval_gpu = run_conv_nnet2_classif(tcn.shared_constructor, isize, ksize, bsize) rval_gpu, tg = run_conv_nnet2_classif(tcn.shared_constructor, isize, ksize, bsize)
assert numpy.allclose(rval_cpu, rval_gpu,rtol=1e-4,atol=1e-6) print "cpu:", rval_cpu
print "gpu:", rval_gpu
print "abs diff:", numpy.absolute(rval_gpu-rval_cpu)
print "time cpu: %f, time gpu: %f, speed up %f"%(tc, tg, tc/tg)
assert numpy.allclose(rval_cpu[:2], rval_gpu[:2],rtol=1e-4,atol=1e-6)
def test_lenet_28(): #MNIST def test_lenet_28(): #MNIST
run_test_conv_nnet2_classif(23485, 28, 5) run_test_conv_nnet2_classif(23485, 28, 5)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论