提交 be95a5ce authored 作者: James Bergstra's avatar James Bergstra

Merge pull request #684 from nouiz/blas

Improve gemm optimizer to never accidentally duplicate work.
...@@ -33,7 +33,8 @@ from optdb import \ ...@@ -33,7 +33,8 @@ from optdb import \
EquilibriumDB, SequenceDB, ProxyDB EquilibriumDB, SequenceDB, ProxyDB
from toolbox import \ from toolbox import \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\
PrintListener, ReplacementDidntRemovedError
from type import \ from type import \
Type, Generic, generic Type, Generic, generic
......
...@@ -9,6 +9,15 @@ class AlreadyThere(Exception): ...@@ -9,6 +9,15 @@ class AlreadyThere(Exception):
pass pass
class ReplacementDidntRemovedError(Exception):
"""This exception should be thrown by replace_all_validate_remove
when an optimization wanted to remove a Variable or a Node from
the graph, but the replacement it gived didn't do that.
"""
pass
class Bookkeeper: class Bookkeeper:
def on_attach(self, env): def on_attach(self, env):
...@@ -91,12 +100,15 @@ class ReplaceValidate(History, Validator): ...@@ -91,12 +100,15 @@ class ReplaceValidate(History, Validator):
" or in conflict with another plugin.") " or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env) env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env) env.replace_all_validate = partial(self.replace_all_validate, env)
env.replace_all_validate_remove = partial(
self.replace_all_validate_remove, env)
def on_detach(self, env): def on_detach(self, env):
History.on_detach(self, env) History.on_detach(self, env)
Validator.on_detach(self, env) Validator.on_detach(self, env)
del env.replace_validate del env.replace_validate
del env.replace_all_validate del env.replace_all_validate
del env.replace_all_validate_remove
def replace_validate(self, env, r, new_r, reason=None): def replace_validate(self, env, r, new_r, reason=None):
self.replace_all_validate(env, [(r, new_r)], reason=reason) self.replace_all_validate(env, [(r, new_r)], reason=reason)
...@@ -121,6 +133,28 @@ class ReplaceValidate(History, Validator): ...@@ -121,6 +133,28 @@ class ReplaceValidate(History, Validator):
except Exception, e: except Exception, e:
env.revert(chk) env.revert(chk)
raise raise
return chk
def replace_all_validate_remove(self, env, replacements,
remove, reason=None):
"""As replace_all_validate, revert the replacement if the ops
in the list remove are still in the graph. It also print a warning.
"""
chk = env.replace_all_validate(replacements, reason)
for rm in remove:
if rm in env.nodes or rm in env.variables:
env.revert(chk)
out = sys.stderr
print >> out, (
"WARNING: An optimization wanted to replace a Variable"
" in the graph, but the replacement for it doesn't"
" remove it. We disabled the optimization."
" Your function runs correctly, but it would be"
" appreciated if you submit this problem to the mailing"
" list theano-users so that we can fix it.")
print >> out, reason, replacements
raise ReplacementDidntRemovedError()
class NodeFinder(dict, Bookkeeper): class NodeFinder(dict, Bookkeeper):
......
...@@ -133,8 +133,10 @@ import numpy.distutils ...@@ -133,8 +133,10 @@ import numpy.distutils
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, view_roots, DestroyHandler, from theano.gof import (utils, Op, view_roots, DestroyHandler,
local_optimizer, Optimizer, local_optimizer, Optimizer,
InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer, Apply) InconsistencyError, toolbox, SequenceDB,
EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError)
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb from theano.compile.mode import optdb
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
...@@ -1022,7 +1024,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1022,7 +1024,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
Ml, Mr = M.owner.inputs Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)] rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
#print 'GEMM 0', rval, beta, L, alpha, M #print 'GEMM 0', rval, beta, L, alpha, M
return rval return rval, M
# it also might be the case that there is a dimshuffle between the + # it also might be the case that there is a dimshuffle between the +
# and the dot22. local_dot_to_dot22 in particular will put in such things. # and the dot22. local_dot_to_dot22 in particular will put in such things.
...@@ -1035,7 +1037,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1035,7 +1037,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle(0, 'x'), g = gemm_no_inplace(L.dimshuffle(0, 'x'),
alpha, MMl, MMr, beta) alpha, MMl, MMr, beta)
rval = [g.dimshuffle(0)] rval = [g.dimshuffle(0)]
return rval return rval, MM
if tuple(M.owner.op.new_order) == (1,): if tuple(M.owner.op.new_order) == (1,):
# it is making a row MM into a vector # it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22: if MM.owner and MM.owner.op == _dot22:
...@@ -1043,7 +1045,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1043,7 +1045,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle('x', 0), g = gemm_no_inplace(L.dimshuffle('x', 0),
alpha, MMl, MMr, beta) alpha, MMl, MMr, beta)
rval = [g.dimshuffle(1)] rval = [g.dimshuffle(1)]
return rval return rval, MM
if tuple(M.owner.op.new_order) == (): if tuple(M.owner.op.new_order) == ():
# it is making a row MM into a vector # it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22: if MM.owner and MM.owner.op == _dot22:
...@@ -1051,7 +1053,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1051,7 +1053,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle('x', 'x'), g = gemm_no_inplace(L.dimshuffle('x', 'x'),
alpha, MMl, MMr, beta) alpha, MMl, MMr, beta)
rval = [g.dimshuffle()] rval = [g.dimshuffle()]
return rval return rval, MM
# this is False'd out because of inadequate testing. # this is False'd out because of inadequate testing.
# TODO see ticket #237 # TODO see ticket #237
...@@ -1085,7 +1087,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1085,7 +1087,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
if recurse_flip: if recurse_flip:
return _beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip=False) return _beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip=False)
else: else:
return False return False, False
def _gemm_canonicalize(r, scale, rval, maxclients): def _gemm_canonicalize(r, scale, rval, maxclients):
...@@ -1250,7 +1252,8 @@ def _gemm_from_factored_list(lst): ...@@ -1250,7 +1252,8 @@ def _gemm_from_factored_list(lst):
#print 'TRYING', (s_i, M_i, s_j, M_j) #print 'TRYING', (s_i, M_i, s_j, M_j)
gemm_of_sM_list = _beta_L_plus_alpha_M(s_i, M_i, s_j, M_j) gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(s_i, M_i,
s_j, M_j)
#print 'GOT IT', gemm_of_sM_list #print 'GOT IT', gemm_of_sM_list
if gemm_of_sM_list: if gemm_of_sM_list:
def item_to_var(t): def item_to_var(t):
...@@ -1273,7 +1276,7 @@ def _gemm_from_factored_list(lst): ...@@ -1273,7 +1276,7 @@ def _gemm_from_factored_list(lst):
else: else:
rval = add_inputs rval = add_inputs
#print "RETURNING GEMM THIGN", rval #print "RETURNING GEMM THIGN", rval
return rval return rval, old_dot22
def _gemm_from_node2(node): def _gemm_from_node2(node):
...@@ -1301,7 +1304,7 @@ def _gemm_from_node2(node): ...@@ -1301,7 +1304,7 @@ def _gemm_from_node2(node):
# http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5, # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
# but never made it into a trac ticket. # but never made it into a trac ticket.
if rval and (rval[0].type == node.outputs[0].type): if rval and (rval[0][0].type == node.outputs[0].type):
return rval return rval
...@@ -1326,11 +1329,13 @@ class GemmOptimizer(Optimizer): ...@@ -1326,11 +1329,13 @@ class GemmOptimizer(Optimizer):
except InconsistencyError, e: except InconsistencyError, e:
continue continue
if new_outputs: if new_outputs:
new_outputs, old_dot22 = new_outputs
assert len(new_outputs) == len(node.outputs) assert len(new_outputs) == len(node.outputs)
try: try:
env.replace_all_validate( env.replace_all_validate_remove(
zip(node.outputs, new_outputs), zip(node.outputs, new_outputs),
reason='GemmOptimizer' [old_dot22],
reason='GemmOptimizer'
) )
did_something = True did_something = True
break break
...@@ -1338,6 +1343,8 @@ class GemmOptimizer(Optimizer): ...@@ -1338,6 +1343,8 @@ class GemmOptimizer(Optimizer):
# TODO: retry other applications of gemm (see comment # TODO: retry other applications of gemm (see comment
# in _gemm_from_node) # in _gemm_from_node)
pass pass
except ReplacementDidntRemovedError, e:
pass
class Dot22(GemmRelated): class Dot22(GemmRelated):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论