提交 c4e182ab authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1666 from nouiz/gh-1381

fix gh-1381 c linker crash with not used inputs.
......@@ -450,8 +450,13 @@ class CLinker(link.Linker):
fgraph = self.fgraph
self.inputs = fgraph.inputs
self.outputs = fgraph.outputs
# list(fgraph.variables)
self.variables = graph.variables(self.inputs, self.outputs)
# We need to include the not used inputs in our variables,
# otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(var.clients)]
self.variables += graph.variables(self.inputs, self.outputs)
# The orphans field is listified to ensure a consistent order.
#list(fgraph.orphans.difference(self.outputs))
self.orphans = list(r for r in self.variables
......@@ -1042,8 +1047,9 @@ class CLinker(link.Linker):
c_compiler=self.c_compiler(),
)
def cmodule_key_(self, fgraph, no_recycling, compile_args=None, libraries=None,
header_dirs=None, insert_config_md5=True, c_compiler=None):
def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
libraries=None, header_dirs=None, insert_config_md5=True,
c_compiler=None):
"""
Do the actual computation of cmodule_key in a static method
to allow it to be reused in scalar.Composite.__eq__
......@@ -1180,6 +1186,11 @@ class CLinker(link.Linker):
op_pos[node] = node_pos
fgraph_computed_set.update(node.outputs)
# Add not used input in the key
for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs)
if not len(var.clients)]:
sig.append((var.type, in_sig(var, -1, ipos)))
#crystalize the signature and version
sig = tuple(sig)
version = tuple(version)
......
......@@ -3,8 +3,9 @@ import unittest
from nose.plugins.skip import SkipTest
import theano
from theano.gof.link import PerformLinker
from theano.gof.cc import *
from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker
from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op
......@@ -227,6 +228,18 @@ def test_clinker_dups():
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_not_used_inputs():
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
e = add(x, y)
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
assert fn(2.0, 1.5, 1.0) == 3.5
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_dups_inner():
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
......@@ -253,6 +266,7 @@ def test_opwiseclinker_straightforward():
# The python version of bad_sub always return -10.
assert fn(2.0, 2.0, 2.0) == -6
def test_opwiseclinker_constant():
x, y, z = inputs()
x = Constant(tdouble, 7.2, name='x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论