提交 97ab0274 authored 作者: ChienliMa's avatar ChienliMa

Codes should handle node with all Constant inputs

上级 9b635db7
......@@ -150,7 +150,7 @@ class OpFromGraph(gof.Op):
c_map = {}
num_of_input = len(fgraph.inputs)
# Initialize input connection pattern, each input affects itself
for index in range(num_of_input):
for index in xrange(num_of_input):
vec = [False] * num_of_input
vec[index] = True
# Make use of numpy.array to simplify codes
......@@ -163,6 +163,8 @@ class OpFromGraph(gof.Op):
for var in node.inputs:
if not isinstance(var, theano.Constant):
in_vecs.append(c_map[var])
else:
in_vecs.append(numpy.array([False] * num_of_input))
if not hasattr(node.op, 'connection_pattern'):
# By default, nodes inputs affect all outputs
......
......@@ -145,9 +145,10 @@ class T_OpFromGraph(unittest.TestCase):
out3 = 3 + rv_u
op3 = OpFromGraph([x, y], [out1, out2, out3], mode='FAST_RUN')
results = op2.connection_pattern(None)
results = op3.connection_pattern(None)
expect_result = [[True, False, False],
[False, True, False]]
[False, True, False],
[True, False, True]]
assert results == expect_result
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论