提交 77f61060 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Assert that test values (if any) of replaced nodes match with replacement

上级 601088bb
...@@ -6,6 +6,8 @@ types that it can raise ...@@ -6,6 +6,8 @@ types that it can raise
""" """
import sys import sys
import theano
from theano import gof
from theano.gof import graph from theano.gof import graph
from theano.gof import utils from theano.gof import utils
from theano.gof import toolbox from theano.gof import toolbox
...@@ -431,6 +433,21 @@ class FunctionGraph(utils.object2): ...@@ -431,6 +433,21 @@ class FunctionGraph(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops # because it makes it easier to implement some optimizations for multiple-output ops
return return
if theano.config.compute_test_value != 'off':
try:
tval = gof.op.get_test_value(r)
new_tval = gof.op.get_test_value(new_r)
except AttributeError:
pass
else:
if tval.shape != new_tval.shape:
raise AssertionError(
"The replacement variable has a test value with "
"a shape different from the original variable's "
"test value. Original: %s, new: %s"
% (tval.shape, new_tval.shape),
r, new_r, str(reason))
for node, i in list(r.clients): # copy the client list for iteration for node, i in list(r.clients): # copy the client list for iteration
assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r) assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason) self.change_input(node, i, new_r, reason=reason)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论