提交 cd6f4ef9 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make fg.replace error give more information

上级 95353815
...@@ -15,6 +15,7 @@ from theano.gof import toolbox ...@@ -15,6 +15,7 @@ from theano.gof import toolbox
from theano import config from theano import config
from six import iteritems, itervalues from six import iteritems, itervalues
from six.moves import StringIO
from theano.gof.utils import get_variable_trace_string from theano.gof.utils import get_variable_trace_string
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
NullType = None NullType = None
...@@ -468,10 +469,20 @@ class FunctionGraph(utils.object2): ...@@ -468,10 +469,20 @@ class FunctionGraph(utils.object2):
new_r2 = r.type.convert_variable(new_r) new_r2 = r.type.convert_variable(new_r)
# We still make sure that the type converts correctly # We still make sure that the type converts correctly
if new_r2 is None or new_r2.type != r.type: if new_r2 is None or new_r2.type != r.type:
raise TypeError("The type of the replacement must be " done = dict()
"compatible with the type of the original " used_ids = dict()
"Variable.", r, new_r, r.type, new_r.type, old = theano.compile.debugmode.debugprint(
str(reason)) r, prefix=' ', depth=6,
file=StringIO(), done=done,
print_type=True,
used_ids=used_ids).getvalue()
new = theano.compile.debugmode.debugprint(
new_r, prefix=' ', depth=6,
file=StringIO(), done=done,
print_type=True,
used_ids=used_ids).getvalue()
raise toolbox.BadOptimization(
r, new_r, None, None, reason, old, new)
new_r = new_r2 new_r = new_r2
if r not in self.variables: if r not in self.variables:
# this variable isn't in the graph... don't raise an # this variable isn't in the graph... don't raise an
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论