提交 eb535917 authored 作者: James Bergstra's avatar James Bergstra

DebugMode - modified linker to override TensorType.filter_checks_isfinite only

for the duration of a single function evaluation, rather than globally at the creation of a DebugMode function
上级 29efb1ad
...@@ -868,6 +868,11 @@ class _Linker(gof.link.LocalLinker): ...@@ -868,6 +868,11 @@ class _Linker(gof.link.LocalLinker):
return self return self
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler = None, input_storage = None, output_storage = None):
if 1:
#can't import at toplevel because of circular import
# TODO: don't do this ugly hacky way of setting the filter_checks_isfinite
from theano.tensor import TensorType #to set filter_check_isfinite
env = self.env env = self.env
input_storage_ = input_storage input_storage_ = input_storage
output_storage_ = output_storage output_storage_ = output_storage
...@@ -932,7 +937,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -932,7 +937,7 @@ class _Linker(gof.link.LocalLinker):
# This is the function that runs when you evaluate the graph # This is the function that runs when you evaluate the graph
##### #####
def f(): def f():
debug("starting f") debug("starting a DebugMode call")
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
...@@ -1027,7 +1032,10 @@ class _Linker(gof.link.LocalLinker): ...@@ -1027,7 +1032,10 @@ class _Linker(gof.link.LocalLinker):
storage_map[r][0] = _lessbroken_deepcopy(r_vals[r]) storage_map[r][0] = _lessbroken_deepcopy(r_vals[r])
debug(i, "DEBUGMODE running thunk_c") debug(i, "DEBUGMODE running thunk_c")
thunk_c() try:
thunk_c()
except:
raise_with_op(node)
for r in node.outputs: for r in node.outputs:
# check output values for type-correctness # check output values for type-correctness
...@@ -1075,9 +1083,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -1075,9 +1083,6 @@ class _Linker(gof.link.LocalLinker):
if True: if True:
gc.collect() gc.collect()
#except:
# raise_with_op(node)
_find_bad_optimizations(order, env.equivalence_tracker.reasons, r_vals) _find_bad_optimizations(order, env.equivalence_tracker.reasons, r_vals)
##### #####
...@@ -1132,10 +1137,27 @@ class _Linker(gof.link.LocalLinker): ...@@ -1132,10 +1137,27 @@ class _Linker(gof.link.LocalLinker):
if (r.owner is None): if (r.owner is None):
assert storage_map[r][0] is not None assert storage_map[r][0] is not None
############### ###############
# Done f # Done debugmode function call 'f'
############## ##############
def run_with_tensortype_filter_check(f):
def deco():
# WARNING: this is a global mechanism...
# so it will screw up if we are trying to use
# multiple modes at once.
old_filter_checks_isfinite = TensorType.filter_checks_isfinite
TensorType.filter_checks_isfinite = self.maker.mode.check_isfinite
try:
return f()
finally:
# put back the filter_checks_isfinite
TensorType.filter_checks_isfinite = old_filter_checks_isfinite
return deco
f = run_with_tensortype_filter_check(f)
f.allow_gc = True f.allow_gc = True
assert len(env.inputs) == len(input_storage) assert len(env.inputs) == len(input_storage)
assert len(env.outputs) == len(output_storage) assert len(env.outputs) == len(output_storage)
...@@ -1170,11 +1192,6 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions ...@@ -1170,11 +1192,6 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions
""" """
# WARNING: this is a global mechanism... so it will screw up if we are trying to use
# multiple modes at once.
from theano.tensor import TensorType #to set filter_check_isfinite
TensorType.filter_checks_isfinite = mode.check_isfinite
# Handle the case where inputs and/or outputs is a single Variable (not in a list) # Handle the case where inputs and/or outputs is a single Variable (not in a list)
unpack_single = False unpack_single = False
return_none = False return_none = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论