提交 316ac334 authored 作者: Frederic's avatar Frederic

fix gh-1381 c linker crash with not used inputs.

上级 b1f1e62a
......@@ -450,8 +450,13 @@ class CLinker(link.Linker):
fgraph = self.fgraph
self.inputs = fgraph.inputs
self.outputs = fgraph.outputs
# list(fgraph.variables)
self.variables = graph.variables(self.inputs, self.outputs)
# We need to include the not used inputs in our variables,
# otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(var.clients)]
self.variables += graph.variables(self.inputs, self.outputs)
# The orphans field is listified to ensure a consistent order.
#list(fgraph.orphans.difference(self.outputs))
self.orphans = list(r for r in self.variables
......@@ -1180,6 +1185,11 @@ class CLinker(link.Linker):
op_pos[node] = node_pos
fgraph_computed_set.update(node.outputs)
# Add not used input in the key
for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs)
if not len(var.clients)]:
sig.append((var.type, in_sig(var, node_pos, ipos)))
#crystalize the signature and version
sig = tuple(sig)
version = tuple(version)
......
......@@ -227,6 +227,18 @@ def test_clinker_dups():
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_not_used_inputs():
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
e = add(x, y)
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
assert fn(2.0, 1.5, 1.0) == 3.5
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_dups_inner():
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论