提交 c4e182ab authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1666 from nouiz/gh-1381

fix gh-1381 c linker crash with not used inputs.
...@@ -450,8 +450,13 @@ class CLinker(link.Linker): ...@@ -450,8 +450,13 @@ class CLinker(link.Linker):
fgraph = self.fgraph fgraph = self.fgraph
self.inputs = fgraph.inputs self.inputs = fgraph.inputs
self.outputs = fgraph.outputs self.outputs = fgraph.outputs
# list(fgraph.variables) # list(fgraph.variables)
self.variables = graph.variables(self.inputs, self.outputs) # We need to include the not used inputs in our variables,
# otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(var.clients)]
self.variables += graph.variables(self.inputs, self.outputs)
# The orphans field is listified to ensure a consistent order. # The orphans field is listified to ensure a consistent order.
#list(fgraph.orphans.difference(self.outputs)) #list(fgraph.orphans.difference(self.outputs))
self.orphans = list(r for r in self.variables self.orphans = list(r for r in self.variables
...@@ -1042,8 +1047,9 @@ class CLinker(link.Linker): ...@@ -1042,8 +1047,9 @@ class CLinker(link.Linker):
c_compiler=self.c_compiler(), c_compiler=self.c_compiler(),
) )
def cmodule_key_(self, fgraph, no_recycling, compile_args=None, libraries=None, def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
header_dirs=None, insert_config_md5=True, c_compiler=None): libraries=None, header_dirs=None, insert_config_md5=True,
c_compiler=None):
""" """
Do the actual computation of cmodule_key in a static method Do the actual computation of cmodule_key in a static method
to allow it to be reused in scalar.Composite.__eq__ to allow it to be reused in scalar.Composite.__eq__
...@@ -1180,6 +1186,11 @@ class CLinker(link.Linker): ...@@ -1180,6 +1186,11 @@ class CLinker(link.Linker):
op_pos[node] = node_pos op_pos[node] = node_pos
fgraph_computed_set.update(node.outputs) fgraph_computed_set.update(node.outputs)
# Add not used input in the key
for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs)
if not len(var.clients)]:
sig.append((var.type, in_sig(var, -1, ipos)))
#crystalize the signature and version #crystalize the signature and version
sig = tuple(sig) sig = tuple(sig)
version = tuple(version) version = tuple(version)
......
...@@ -3,8 +3,9 @@ import unittest ...@@ -3,8 +3,9 @@ import unittest
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import theano
from theano.gof.link import PerformLinker from theano.gof.link import PerformLinker
from theano.gof.cc import * from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
...@@ -227,6 +228,18 @@ def test_clinker_dups(): ...@@ -227,6 +228,18 @@ def test_clinker_dups():
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_not_used_inputs():
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
e = add(x, y)
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
assert fn(2.0, 1.5, 1.0) == 3.5
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_dups_inner(): def test_clinker_dups_inner():
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
...@@ -253,6 +266,7 @@ def test_opwiseclinker_straightforward(): ...@@ -253,6 +266,7 @@ def test_opwiseclinker_straightforward():
# The python version of bad_sub always return -10. # The python version of bad_sub always return -10.
assert fn(2.0, 2.0, 2.0) == -6 assert fn(2.0, 2.0, 2.0) == -6
def test_opwiseclinker_constant(): def test_opwiseclinker_constant():
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name='x') x = Constant(tdouble, 7.2, name='x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论