提交 9e12d2c0 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

file theano/misc/check_blas.py

上级 bab1c1b8
...@@ -13,7 +13,7 @@ import time ...@@ -13,7 +13,7 @@ import time
from optparse import OptionParser from optparse import OptionParser
import subprocess import subprocess
import numpy import numpy as np
import theano import theano
import theano.tensor as T import theano.tensor as T
...@@ -47,10 +47,10 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000, ...@@ -47,10 +47,10 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000,
print() print()
print('Numpy config: (used when the Theano flag' print('Numpy config: (used when the Theano flag'
' "blas.ldflags" is empty)') ' "blas.ldflags" is empty)')
numpy.show_config() np.show_config()
print('Numpy dot module:', numpy.dot.__module__) print('Numpy dot module:', np.dot.__module__)
print('Numpy location:', numpy.__file__) print('Numpy location:', np.__file__)
print('Numpy version:', numpy.__version__) print('Numpy version:', np.__version__)
if (theano.config.device.startswith("gpu") or if (theano.config.device.startswith("gpu") or
theano.config.init_gpu_device.startswith("gpu")): theano.config.init_gpu_device.startswith("gpu")):
print('nvcc version:') print('nvcc version:')
...@@ -58,12 +58,12 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000, ...@@ -58,12 +58,12 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000,
"--version")) "--version"))
print() print()
a = theano.shared(numpy.ones((M, N), dtype=theano.config.floatX, a = theano.shared(np.ones((M, N), dtype=theano.config.floatX,
order=order)) order=order))
b = theano.shared(numpy.ones((N, K), dtype=theano.config.floatX, b = theano.shared(np.ones((N, K), dtype=theano.config.floatX,
order=order)) order=order))
c = theano.shared(numpy.ones((M, K), dtype=theano.config.floatX, c = theano.shared(np.ones((M, K), dtype=theano.config.floatX,
order=order)) order=order))
f = theano.function([], updates=[(c, 0.4 * c + .8 * T.dot(a, b))]) f = theano.function([], updates=[(c, 0.4 * c + .8 * T.dot(a, b))])
if any([x.op.__class__.__name__ == 'Gemm' for x in if any([x.op.__class__.__name__ == 'Gemm' for x in
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论