提交 77f61060 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Assert that test values (if any) of replaced nodes match with replacement

上级 601088bb
......@@ -6,6 +6,8 @@ types that it can raise
"""
import sys
import theano
from theano import gof
from theano.gof import graph
from theano.gof import utils
from theano.gof import toolbox
......@@ -431,6 +433,21 @@ class FunctionGraph(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops
return
if theano.config.compute_test_value != 'off':
try:
tval = gof.op.get_test_value(r)
new_tval = gof.op.get_test_value(new_r)
except AttributeError:
pass
else:
if tval.shape != new_tval.shape:
raise AssertionError(
"The replacement variable has a test value with "
"a shape different from the original variable's "
"test value. Original: %s, new: %s"
% (tval.shape, new_tval.shape),
r, new_r, str(reason))
for node, i in list(r.clients): # copy the client list for iteration
assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论