提交 834f0b6a authored 作者: James Bergstra's avatar James Bergstra

reasons are being passed around during optimization... with an ugly hack to env.execute_callback

上级 52dff24f
...@@ -297,10 +297,7 @@ class Env(utils.object2): ...@@ -297,10 +297,7 @@ class Env(utils.object2):
self.__import_r__([new_r]) self.__import_r__([new_r])
self.__add_clients__(new_r, [(node, i)]) self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False) prune = self.__remove_clients__(r, [(node, i)], False)
if reason is None: self.execute_callbacks('on_change_input', node, i, r, new_r, reason=reason)
self.execute_callbacks('on_change_input', node, i, r, new_r)
else:
self.execute_callbacks('on_change_input_with_reason', node, i, r, new_r, reason)
if prune: if prune:
self.__prune_r__([r]) self.__prune_r__([r])
...@@ -367,7 +364,7 @@ class Env(utils.object2): ...@@ -367,7 +364,7 @@ class Env(utils.object2):
### callback utils ### ### callback utils ###
def execute_callbacks(self, name, *args): def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME """WRITEME
Calls Calls
getattr(feature, name)(*args) getattr(feature, name)(*args)
...@@ -378,7 +375,16 @@ class Env(utils.object2): ...@@ -378,7 +375,16 @@ class Env(utils.object2):
fn = getattr(feature, name) fn = getattr(feature, name)
except AttributeError: except AttributeError:
continue continue
fn(self, *args)
#####HORRIBLE OPTIONAL ARGUMENT HACK
try:
fn(self, *args, **kwargs)
except TypeError, e:
if str(e) == "on_change_input() got an unexpected keyword argument 'reason'" and len(kwargs) == 1:
fn(self, *args)
else:
raise
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
"""WRITEME """WRITEME
......
...@@ -189,7 +189,7 @@ class MergeOptimizer(Optimizer): ...@@ -189,7 +189,7 @@ class MergeOptimizer(Optimizer):
# we adopt convention to keep the last name # we adopt convention to keep the last name
if c.name: if c.name:
other_c.name = c.name other_c.name = c.name
env.replace_validate(c, other_c) env.replace_validate(c, other_c, reason='Constant Merge')
else: else:
#this is a new constant #this is a new constant
const_sig[c] = sig const_sig[c] = sig
...@@ -219,7 +219,7 @@ class MergeOptimizer(Optimizer): ...@@ -219,7 +219,7 @@ class MergeOptimizer(Optimizer):
if output.name and not new_output.name: if output.name and not new_output.name:
new_output.name = output.name new_output.name = output.name
try: try:
env.replace_all_validate(pairs) env.replace_all_validate(pairs, reason='Merge (exptime)')
except InconsistencyError, e: except InconsistencyError, e:
success = False success = False
if not success: if not success:
...@@ -266,7 +266,7 @@ class MergeOptimizer(Optimizer): ...@@ -266,7 +266,7 @@ class MergeOptimizer(Optimizer):
if node_output.name: if node_output.name:
cand_output.name = node_output.name cand_output.name = node_output.name
try: try:
env.replace_all_validate(pairs) env.replace_all_validate(pairs, reason="Merge")
except InconsistencyError, e: except InconsistencyError, e:
success = False success = False
...@@ -714,7 +714,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -714,7 +714,7 @@ class NavigatorOptimizer(Optimizer):
return False return False
repl_pairs = zip(node.outputs, replacements) repl_pairs = zip(node.outputs, replacements)
try: try:
env.replace_all_validate(repl_pairs) env.replace_all_validate(repl_pairs, reason=lopt)
return True return True
except Exception, e: except Exception, e:
# This means the replacements were rejected by the env. # This means the replacements were rejected by the env.
......
...@@ -36,11 +36,11 @@ class History: ...@@ -36,11 +36,11 @@ class History:
del env.revert del env.revert
del self.history[env] del self.history[env]
def on_change_input(self, env, node, i, r, new_r): def on_change_input(self, env, node, i, r, new_r, reason=None):
if self.history[env] is None: if self.history[env] is None:
return return
h = self.history[env] h = self.history[env]
h.append(lambda: env.change_input(node, i, r)) h.append(lambda: env.change_input(node, i, r, reason=("Revert", reason)))
def revert(self, env, checkpoint): def revert(self, env, checkpoint):
""" """
...@@ -92,14 +92,14 @@ class ReplaceValidate(History, Validator): ...@@ -92,14 +92,14 @@ class ReplaceValidate(History, Validator):
del env.replace_validate del env.replace_validate
del env.replace_all_validate del env.replace_all_validate
def replace_validate(self, env, r, new_r): def replace_validate(self, env, r, new_r, reason=None):
self.replace_all_validate(env, [(r, new_r)]) self.replace_all_validate(env, [(r, new_r)], reason=reason)
def replace_all_validate(self, env, replacements): def replace_all_validate(self, env, replacements, reason=None):
chk = env.checkpoint() chk = env.checkpoint()
for r, new_r in replacements: for r, new_r in replacements:
try: try:
env.replace(r, new_r) env.replace(r, new_r, reason=reason)
except Exception, e: except Exception, e:
if 'The type of the replacement must be the same' not in str(e) and 'does not belong to this Env' not in str(e): if 'The type of the replacement must be the same' not in str(e) and 'does not belong to this Env' not in str(e):
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e
......
...@@ -64,7 +64,8 @@ def insert_inplace_optimizer(env): ...@@ -64,7 +64,8 @@ def insert_inplace_optimizer(env):
*[inplace_pattern.get(i, None) \ *[inplace_pattern.get(i, None) \
for i in xrange(len(node.outputs))])), for i in xrange(len(node.outputs))])),
inplace_pattern).make_node(*node.inputs) inplace_pattern).make_node(*node.inputs)
env.replace_all_validate(zip(node.outputs, new.outputs)) env.replace_all_validate(zip(node.outputs, new.outputs),
reason="insert_inplace_optimizer")
except (ValueError, TypeError, InconsistencyError), e: except (ValueError, TypeError, InconsistencyError), e:
continue continue
candidate_inputs.remove(candidate_input) candidate_inputs.remove(candidate_input)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论