提交 d604b4be authored 作者: Frederic's avatar Frederic

Allow an optimizer to make sure its replacement removes a Variable/Apply node from the graph.

When it is not the case, don't apply the optimization and warn about it.
上级 0c02c4be
......@@ -33,7 +33,8 @@ from optdb import \
EquilibriumDB, SequenceDB, ProxyDB
from toolbox import \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\
PrintListener, ReplacementDidntRemovedError
from type import \
Type, Generic, generic
......
......@@ -9,6 +9,15 @@ class AlreadyThere(Exception):
pass
class ReplacementDidntRemovedError(Exception):
    """Raised by ``replace_all_validate_remove`` when an optimization
    asked for a Variable or an Apply node to be removed from the graph,
    but the replacement it supplied did not actually remove it.

    NOTE(review): the ungrammatical class name ("DidntRemoved") is kept
    as-is because external callers catch it by this exact name.
    """
    pass
class Bookkeeper:
def on_attach(self, env):
......@@ -91,12 +100,15 @@ class ReplaceValidate(History, Validator):
" or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env)
env.replace_all_validate_remove = partial(
self.replace_all_validate_remove, env)
def on_detach(self, env):
    """Undo on_attach: detach the History and Validator features and
    strip the replace_* entry points previously installed on *env*."""
    History.on_detach(self, env)
    Validator.on_detach(self, env)
    # Remove every callback that on_attach bound onto the env.
    for attr_name in ("replace_validate",
                      "replace_all_validate",
                      "replace_all_validate_remove"):
        delattr(env, attr_name)
def replace_validate(self, env, r, new_r, reason=None):
    """Replace the single variable *r* by *new_r* in *env*, validating
    the change.

    Thin convenience wrapper: delegates to ``replace_all_validate``
    with a one-element replacement list.
    """
    self.replace_all_validate(env, [(r, new_r)], reason=reason)
......@@ -121,6 +133,28 @@ class ReplaceValidate(History, Validator):
except Exception, e:
env.revert(chk)
raise
return chk
def replace_all_validate_remove(self, env, replacements,
                                remove, reason=None):
    """Like ``replace_all_validate``, but additionally require that the
    replacements actually remove every object listed in *remove*.

    After the replacements are applied and validated, if any entry of
    *remove* is still present in ``env.nodes`` or ``env.variables``,
    the whole change is reverted, a warning is printed to stderr, and
    ``ReplacementDidntRemovedError`` is raised so the caller can skip
    the offending optimization.
    """
    # chk is the history checkpoint returned by replace_all_validate;
    # it lets us revert everything if the removal check below fails.
    chk = env.replace_all_validate(replacements, reason)
    for rm in remove:
        # The replacement was supposed to make rm unreachable; if it is
        # still registered in the graph, back out and warn.
        if rm in env.nodes or rm in env.variables:
            env.revert(chk)
            out = sys.stderr
            print >> out, (
                "WARNING: An optimization wanted to replace a Variable"
                " in the graph, but the replacement for it doesn't"
                " remove it. We disabled the optimization."
                " Your function runs correctly, but it would be"
                " appreciated if you submit this problem to the mailing"
                " list theano-users so that we can fix it.")
            print >> out, reason, replacements
            raise ReplacementDidntRemovedError()
class NodeFinder(dict, Bookkeeper):
......
......@@ -133,8 +133,10 @@ import numpy.distutils
from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, view_roots, DestroyHandler,
local_optimizer, Optimizer,
InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer, Apply)
local_optimizer, Optimizer,
InconsistencyError, toolbox, SequenceDB,
EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError)
from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb
from theano.gof.python25 import all, any
......@@ -1022,7 +1024,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
#print 'GEMM 0', rval, beta, L, alpha, M
return rval
return rval, M
# it also might be the case that there is a dimshuffle between the +
# and the dot22. local_dot_to_dot22 in particular will put in such things.
......@@ -1035,7 +1037,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle(0, 'x'),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle(0)]
return rval
return rval, MM
if tuple(M.owner.op.new_order) == (1,):
# it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22:
......@@ -1043,7 +1045,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle('x', 0),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle(1)]
return rval
return rval, MM
if tuple(M.owner.op.new_order) == ():
# it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22:
......@@ -1051,7 +1053,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle('x', 'x'),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle()]
return rval
return rval, MM
# this is False'd out because of inadequate testing.
# TODO see ticket #237
......@@ -1085,7 +1087,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
if recurse_flip:
return _beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip=False)
else:
return False
return False, False
def _gemm_canonicalize(r, scale, rval, maxclients):
......@@ -1250,7 +1252,8 @@ def _gemm_from_factored_list(lst):
#print 'TRYING', (s_i, M_i, s_j, M_j)
gemm_of_sM_list = _beta_L_plus_alpha_M(s_i, M_i, s_j, M_j)
gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(s_i, M_i,
s_j, M_j)
#print 'GOT IT', gemm_of_sM_list
if gemm_of_sM_list:
def item_to_var(t):
......@@ -1273,7 +1276,7 @@ def _gemm_from_factored_list(lst):
else:
rval = add_inputs
#print "RETURNING GEMM THIGN", rval
return rval
return rval, old_dot22
def _gemm_from_node2(node):
......@@ -1301,7 +1304,7 @@ def _gemm_from_node2(node):
# http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
# but never made it into a trac ticket.
if rval and (rval[0].type == node.outputs[0].type):
if rval and (rval[0][0].type == node.outputs[0].type):
return rval
......@@ -1326,11 +1329,13 @@ class GemmOptimizer(Optimizer):
except InconsistencyError, e:
continue
if new_outputs:
new_outputs, old_dot22 = new_outputs
assert len(new_outputs) == len(node.outputs)
try:
env.replace_all_validate(
zip(node.outputs, new_outputs),
reason='GemmOptimizer'
env.replace_all_validate_remove(
zip(node.outputs, new_outputs),
[old_dot22],
reason='GemmOptimizer'
)
did_something = True
break
......@@ -1338,6 +1343,8 @@ class GemmOptimizer(Optimizer):
# TODO: retry other applications of gemm (see comment
# in _gemm_from_node)
pass
except ReplacementDidntRemovedError, e:
pass
class Dot22(GemmRelated):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论