提交 b29c6fbd authored 作者: James Bergstra's avatar James Bergstra

improved sparse test_basic of csc speed to be more robust

上级 8697ba45
...@@ -291,7 +291,8 @@ class test_structureddot(unittest.TestCase): ...@@ -291,7 +291,8 @@ class test_structureddot(unittest.TestCase):
a = SparseType('csc', dtype=sparse_dtype)() a = SparseType('csc', dtype=sparse_dtype)()
b = tensor.matrix(dtype=dense_dtype) b = tensor.matrix(dtype=dense_dtype)
d = theano.dot(a,b) d = theano.dot(a,b)
f = theano.function([a,b], d, mode='FAST_RUN') #f = theano.function([a,b], theano.Out(d, borrow=True), mode='PROFILE_MODE')
f = theano.function([a,b], theano.Out(d, borrow=True), mode='FAST_RUN')
# technically we could be using DEBUG MODE to verify internal problems. # technically we could be using DEBUG MODE to verify internal problems.
# in fact, if this test fails for correctness, then it would be good to use DEBUG_MODE # in fact, if this test fails for correctness, then it would be good to use DEBUG_MODE
...@@ -310,18 +311,25 @@ class test_structureddot(unittest.TestCase): ...@@ -310,18 +311,25 @@ class test_structureddot(unittest.TestCase):
]: ]:
spmat = sp.csc_matrix(random_lil((M,N), sparse_dtype, nnz)) spmat = sp.csc_matrix(random_lil((M,N), sparse_dtype, nnz))
mat = numpy.asarray(numpy.random.randn(N,K), dense_dtype) mat = numpy.asarray(numpy.random.randn(N,K), dense_dtype)
t0 = time.time() theano_times = []
theano_result = f(spmat, mat) scipy_times = []
t1 = time.time() for i in xrange(5):
scipy_result = spmat * mat t0 = time.time()
t2 = time.time() theano_result = f(spmat, mat)
t1 = time.time()
scipy_result = spmat * mat
t2 = time.time()
theano_time = t1-t0 theano_times.append(t1-t0)
scipy_time = t2-t1 scipy_times.append(t2-t1)
#print theano_result
#print scipy_result theano_time = numpy.min(theano_times)
print 'theano took', theano_time, scipy_time = numpy.min(scipy_times)
print 'scipy took', scipy_time
speedup = scipy_time / theano_time
print scipy_times
print theano_times
print 'M=%(M)s N=%(N)s K=%(K)s nnz=%(nnz)s theano_time=%(theano_time)s speedup=%(speedup)s' % locals()
# fail if Theano is slower than scipy by more than a certain amount # fail if Theano is slower than scipy by more than a certain amount
overhead_tol = 0.003 # seconds overall overhead_tol = 0.003 # seconds overall
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论