提交 580d0b6c 作者: Razvan Pascanu

Merge pull request #165 from nouiz/fix

Fix. Looks fine; mostly it is PEP 8 compatibility fixes. -- Razvan
...@@ -27,7 +27,8 @@ class History: ...@@ -27,7 +27,8 @@ class History:
def on_attach(self, env): def on_attach(self, env):
if hasattr(env, 'checkpoint') or hasattr(env, 'revert'): if hasattr(env, 'checkpoint') or hasattr(env, 'revert'):
raise AlreadyThere("History feature is already present or in conflict with another plugin.") raise AlreadyThere("History feature is already present or in"
" conflict with another plugin.")
self.history[env] = [] self.history[env] = []
env.checkpoint = lambda: len(self.history[env]) env.checkpoint = lambda: len(self.history[env])
env.revert = partial(self.revert, env) env.revert = partial(self.revert, env)
...@@ -41,7 +42,8 @@ class History: ...@@ -41,7 +42,8 @@ class History:
if self.history[env] is None: if self.history[env] is None:
return return
h = self.history[env] h = self.history[env]
h.append(lambda: env.change_input(node, i, r, reason=("Revert", reason))) h.append(lambda: env.change_input(node, i, r,
reason=("Revert", reason)))
def revert(self, env, checkpoint): def revert(self, env, checkpoint):
""" """
...@@ -61,8 +63,10 @@ class Validator: ...@@ -61,8 +63,10 @@ class Validator:
def on_attach(self, env): def on_attach(self, env):
if hasattr(env, 'validate'): if hasattr(env, 'validate'):
raise AlreadyThere("Validator feature is already present or in conflict with another plugin.") raise AlreadyThere("Validator feature is already present or in"
" conflict with another plugin.")
env.validate = lambda: env.execute_callbacks('validate') env.validate = lambda: env.execute_callbacks('validate')
def consistent(): def consistent():
try: try:
env.validate() env.validate()
...@@ -83,7 +87,8 @@ class ReplaceValidate(History, Validator): ...@@ -83,7 +87,8 @@ class ReplaceValidate(History, Validator):
Validator.on_attach(self, env) Validator.on_attach(self, env)
for attr in ('replace_validate', 'replace_all_validate'): for attr in ('replace_validate', 'replace_all_validate'):
if hasattr(env, attr): if hasattr(env, attr):
raise AlreadyThere("ReplaceValidate feature is already present or in conflict with another plugin.") raise AlreadyThere("ReplaceValidate feature is already present"
" or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env) env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env) env.replace_all_validate = partial(self.replace_all_validate, env)
...@@ -101,10 +106,16 @@ class ReplaceValidate(History, Validator): ...@@ -101,10 +106,16 @@ class ReplaceValidate(History, Validator):
for r, new_r in replacements: for r, new_r in replacements:
try: try:
env.replace(r, new_r, reason=reason) env.replace(r, new_r, reason=reason)
print reason
except Exception, e: except Exception, e:
if 'The type of the replacement must be the same' not in str(e) and 'does not belong to this Env' not in str(e): if ('The type of the replacement must be the same' not in
print >> sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e, reason str(e) and 'does not belong to this Env' not in str(e)):
env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling) out = sys.stderr
print >> out, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>",
print >> out, type(e), e, reason
# this might fail if the error is in a listener:
# (env.replace kinda needs better internal error handling)
env.revert(chk)
raise raise
try: try:
env.validate() env.validate()
...@@ -122,14 +133,16 @@ class NodeFinder(dict, Bookkeeper): ...@@ -122,14 +133,16 @@ class NodeFinder(dict, Bookkeeper):
if self.env is not None: if self.env is not None:
raise Exception("A NodeFinder instance can only serve one Env.") raise Exception("A NodeFinder instance can only serve one Env.")
if hasattr(env, 'get_nodes'): if hasattr(env, 'get_nodes'):
raise AlreadyThere("NodeFinder is already present or in conflict with another plugin.") raise AlreadyThere("NodeFinder is already present or in conflict"
" with another plugin.")
self.env = env self.env = env
env.get_nodes = partial(self.query, env) env.get_nodes = partial(self.query, env)
Bookkeeper.on_attach(self, env) Bookkeeper.on_attach(self, env)
def on_detach(self, env): def on_detach(self, env):
if self.env is not env: if self.env is not env:
raise Exception("This NodeFinder instance was not attached to the provided env.") raise Exception("This NodeFinder instance was not attached to the"
" provided env.")
self.env = None self.env = None
del env.get_nodes del env.get_nodes
Bookkeeper.on_detach(self, env) Bookkeeper.on_detach(self, env)
...@@ -137,7 +150,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -137,7 +150,7 @@ class NodeFinder(dict, Bookkeeper):
def on_import(self, env, node): def on_import(self, env, node):
try: try:
self.setdefault(node.op, []).append(node) self.setdefault(node.op, []).append(node)
except TypeError: #node.op is unhashable except TypeError: # node.op is unhashable
return return
except Exception, e: except Exception, e:
print >> sys.stderr, 'OFFENDING node', type(node), type(node.op) print >> sys.stderr, 'OFFENDING node', type(node), type(node.op)
...@@ -150,7 +163,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -150,7 +163,7 @@ class NodeFinder(dict, Bookkeeper):
def on_prune(self, env, node): def on_prune(self, env, node):
try: try:
nodes = self[node.op] nodes = self[node.op]
except TypeError: #node.op is unhashable except TypeError: # node.op is unhashable
return return
nodes.remove(node) nodes.remove(node)
if not nodes: if not nodes:
...@@ -160,14 +173,15 @@ class NodeFinder(dict, Bookkeeper): ...@@ -160,14 +173,15 @@ class NodeFinder(dict, Bookkeeper):
try: try:
all = self.get(op, []) all = self.get(op, [])
except TypeError: except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the optimizer" % op) raise TypeError("%s in unhashable and cannot be queried by the"
" optimizer" % op)
all = list(all) all = list(all)
return all return all
class PrintListener(object): class PrintListener(object):
def __init__(self, active = True): def __init__(self, active=True):
self.active = active self.active = active
def on_attach(self, env): def on_attach(self, env):
...@@ -188,7 +202,8 @@ class PrintListener(object): ...@@ -188,7 +202,8 @@ class PrintListener(object):
def on_change_input(self, env, node, i, r, new_r, reason=None): def on_change_input(self, env, node, i, r, new_r, reason=None):
if self.active: if self.active:
print "-- changing (%s.inputs[%s]) from %s to %s" % (node, i, r, new_r) print "-- changing (%s.inputs[%s]) from %s to %s" % (
node, i, r, new_r)
class PreserveNames: class PreserveNames:
......
...@@ -4,14 +4,16 @@ ...@@ -4,14 +4,16 @@
#A,B,C matrix #A,B,C matrix
#a,b scalar #a,b scalar
s=""" s = """
result for shapes=(2000,2000) and iters=100 result for shapes=(2000,2000) and iters=100
GTX 470 7.22s GTX 470 7.22s
GTX 285, 6.84s GTX 285, 6.84s
GTX 480 5.83s GTX 480 5.83s
""" """
import os, sys, time import os
import sys
import time
import numpy import numpy
import theano import theano
...@@ -20,61 +22,67 @@ import theano.tensor as T ...@@ -20,61 +22,67 @@ import theano.tensor as T
from theano.gof.python25 import any from theano.gof.python25 import any
shapes=(2000,2000) shapes = (2000, 2000)
iters = 10 iters = 10
def execute(execute=True, verbose=True): def execute(execute=True, verbose=True):
a=theano.shared(numpy.ones(shapes, dtype=theano.config.floatX)) a = theano.shared(numpy.ones(shapes, dtype=theano.config.floatX))
b=theano.shared(numpy.ones(shapes, dtype=theano.config.floatX)) b = theano.shared(numpy.ones(shapes, dtype=theano.config.floatX))
c=theano.shared(numpy.ones(shapes, dtype=theano.config.floatX)) c = theano.shared(numpy.ones(shapes, dtype=theano.config.floatX))
f=theano.function([],updates={c:0.4*c+.8*T.dot(a,b)}) f = theano.function([], updates={c: 0.4 * c + .8 * T.dot(a, b)})
if verbose: if verbose:
print 'Some theano flags:' print 'Some theano flags:'
print ' blas.ldflags=',theano.config.blas.ldflags print ' blas.ldflags=', theano.config.blas.ldflags
print ' compiledir=',theano.config.compiledir print ' compiledir=', theano.config.compiledir
print ' floatX=',theano.config.floatX print ' floatX=', theano.config.floatX
print 'Some env flags:' print 'Some env flags:'
print ' MKL_NUM_THREADS=',os.getenv('MKL_NUM_THREADS') print ' MKL_NUM_THREADS=', os.getenv('MKL_NUM_THREADS')
print ' OMP_NUM_THREADS=',os.getenv('OMP_NUM_THREADS') print ' OMP_NUM_THREADS=', os.getenv('OMP_NUM_THREADS')
print ' GOTO_NUM_THREADS=',os.getenv('GOTO_NUM_THREADS') print ' GOTO_NUM_THREADS=', os.getenv('GOTO_NUM_THREADS')
print print
print 'Numpy config:(used when the theano flags "blas.ldflags" is empty)' print ('Numpy config: (used when the theano flags'
numpy.show_config(); ' "blas.ldflags" is empty)')
print 'Numpy dot module:',numpy.dot.__module__; numpy.show_config()
print 'Numpy file location that was loaded:',numpy.__file__; print 'Numpy dot module:', numpy.dot.__module__
print 'Numpy version:',numpy.__version__ print 'Numpy file location that was loaded:', numpy.__file__
print 'Numpy version:', numpy.__version__
print print
if any( [x.op.__class__.__name__=='Gemm' for x in f.maker.env.toposort()]): if any([x.op.__class__.__name__ == 'Gemm' for x in
f.maker.env.toposort()]):
print 'Used the cpu' print 'Used the cpu'
elif any( [x.op.__class__.__name__=='GpuGemm' for x in f.maker.env.toposort()]): elif any([x.op.__class__.__name__ == 'GpuGemm' for x in
f.maker.env.toposort()]):
print 'Used the gpu' print 'Used the gpu'
else: else:
print 'ERROR, not able to tell if theano used the cpu or the gpu' print 'ERROR, not able to tell if theano used the cpu or the gpu'
print f.maker.env.toposort() print f.maker.env.toposort()
t0=0 t0 = 0
t1=-1 t1 = -1
if execute: if execute:
t0=time.time() t0 = time.time()
for i in range(iters): for i in range(iters):
f() f()
t1=time.time() t1 = time.time()
if verbose and execute: if verbose and execute:
print print
print 'This execution time took %.2fs'%(t1-t0) print 'This execution time took %.2fs' % (t1 - t0)
print print
print 'Try to run this script a few times. Experience show that the first time is not as fast as followings call. The difference is not big, but consistent.' print ('Try to run this script a few times. Experience show that'
return t1-t0 ' the first time is not as fast as followings call. The'
' difference is not big, but consistent.')
return t1 - t0
def jobman_job(state, channel): def jobman_job(state, channel):
execute() execute()
return channel.COMPLETE return channel.COMPLETE
def test(): def test():
execute() execute()
...@@ -92,11 +100,15 @@ if __name__ == "__main__": ...@@ -92,11 +100,15 @@ if __name__ == "__main__":
if verbose: if verbose:
print """ print """
Some results that you can compare against. They were 10 executions of gemm in float64 with matrices of shape 2000x2000. Some results that you can compare against. They were 10 executions
of gemm in float64 with matrices of shape 2000x2000.
CPU tested: Xeon E5345(2.33Ghz, 8M L2 cache, 1333Mhz FSB), Xeon E5430(2.66Ghz, 12M L2 cache, 1333Mhz FSB),
Xeon E5450(3Ghz, 12M L2 cache, 1333Mhz FSB), Xeon X5560(2.8Ghz, 12M L2 cache, 6.4GT/s QPI, hyper-threads enabled?) CPU tested: Xeon E5345(2.33Ghz, 8M L2 cache, 1333Mhz FSB),
Core 2 E8500, Core i7 930(2.8Ghz, hyper-threads enabled), Core i7 950(3.07GHz, hyper-threads enabled) Xeon E5430(2.66Ghz, 12M L2 cache, 1333Mhz FSB),
Xeon E5450(3Ghz, 12M L2 cache, 1333Mhz FSB),
Xeon X5560(2.8Ghz, 12M L2 cache, hyper-threads?)
Core 2 E8500, Core i7 930(2.8Ghz, hyper-threads enabled),
Core i7 950(3.07GHz, hyper-threads enabled)
Xeon X5550(2.67GHz, 8M l2 cache?, hyper-threads enabled) Xeon X5550(2.67GHz, 8M l2 cache?, hyper-threads enabled)
...@@ -132,7 +144,8 @@ if __name__ == "__main__": ...@@ -132,7 +144,8 @@ if __name__ == "__main__":
goto2 1.13/16 3.16s goto2 1.13/16 3.16s
Test time in float32 with cuda 3.0.14 Test time in float32 with cuda 3.0.14
(cuda version 3.2RC and up are supposed to have faster gemm on the GTX4?? card) (cuda version 3.2RC and up have a faster gemm on the Fermi/GTX[45]??
gpu/cuda version gpu/cuda version
GTX580/3.2 0.20s GTX580/3.2 0.20s
GTX480/3.2 0.24s GTX480/3.2 0.24s
...@@ -149,6 +162,7 @@ if __name__ == "__main__": ...@@ -149,6 +162,7 @@ if __name__ == "__main__":
""" """
print print
print "We timed",iters,"executions of gemm with matrix of shapes",shapes print "We timed", iters,
print "executions of gemm with matrix of shapes", shapes
else: else:
print t print t
...@@ -11,7 +11,7 @@ TIME_PREFIX=time ...@@ -11,7 +11,7 @@ TIME_PREFIX=time
VAR=OMP_NUM_THREADS VAR=OMP_NUM_THREADS
echo "numpy gemm take=" echo "numpy gemm take="
THEANO_FLAGS=blas.ldflags= $TIME_PREFIX python misc/check_blas.py --quiet THEANO_FLAGS=blas.ldflags= $TIME_PREFIX python misc/check_blas.py --quiet
for i in 1 2 4 8: for i in 1 2 4 8
do do
export $VAR=$i export $VAR=$i
x=`$TIME_PREFIX python misc/check_blas.py --quiet` x=`$TIME_PREFIX python misc/check_blas.py --quiet`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论