提交 c4e182ab authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1666 from nouiz/gh-1381

fix gh-1381 c linker crash with not used inputs.
...@@ -51,8 +51,8 @@ from theano import config ...@@ -51,8 +51,8 @@ from theano import config
# of cutils_ext. # of cutils_ext.
from theano.configparser import AddConfigVar, StrParam from theano.configparser import AddConfigVar, StrParam
AddConfigVar('gcc.cxxflags', AddConfigVar('gcc.cxxflags',
"Extra compiler flags for gcc", "Extra compiler flags for gcc",
StrParam("")) StrParam(""))
# gof imports # gof imports
from theano.gof import graph from theano.gof import graph
...@@ -432,7 +432,7 @@ class CLinker(link.Linker): ...@@ -432,7 +432,7 @@ class CLinker(link.Linker):
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
"""WRITEME""" """WRITEME"""
if no_recycling is None: if no_recycling is None:
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)().accept(fgraph, no_recycling) return type(self)().accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is already" #raise Exception("Cannot accept from a Linker that is already"
...@@ -450,8 +450,13 @@ class CLinker(link.Linker): ...@@ -450,8 +450,13 @@ class CLinker(link.Linker):
fgraph = self.fgraph fgraph = self.fgraph
self.inputs = fgraph.inputs self.inputs = fgraph.inputs
self.outputs = fgraph.outputs self.outputs = fgraph.outputs
# list(fgraph.variables) # list(fgraph.variables)
self.variables = graph.variables(self.inputs, self.outputs) # We need to include the not used inputs in our variables,
# otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(var.clients)]
self.variables += graph.variables(self.inputs, self.outputs)
# The orphans field is listified to ensure a consistent order. # The orphans field is listified to ensure a consistent order.
#list(fgraph.orphans.difference(self.outputs)) #list(fgraph.orphans.difference(self.outputs))
self.orphans = list(r for r in self.variables self.orphans = list(r for r in self.variables
...@@ -528,7 +533,7 @@ class CLinker(link.Linker): ...@@ -528,7 +533,7 @@ class CLinker(link.Linker):
if isinstance(variable, graph.Constant): if isinstance(variable, graph.Constant):
try: try:
symbol[variable] = ("(" + variable.type.c_literal( symbol[variable] = ("(" + variable.type.c_literal(
variable.data) + ")") variable.data) + ")")
self.consts.append(variable) self.consts.append(variable)
self.orphans.remove(variable) self.orphans.remove(variable)
continue continue
...@@ -630,8 +635,8 @@ class CLinker(link.Linker): ...@@ -630,8 +635,8 @@ class CLinker(link.Linker):
else: else:
# The following will be executed if the "try" block succeeds # The following will be executed if the "try" block succeeds
assert isinstance(c_support_code_apply[-1], basestring), ( assert isinstance(c_support_code_apply[-1], basestring), (
str(node.op) + str(node.op) +
" didn't return a string for c_support_code_apply") " didn't return a string for c_support_code_apply")
try: try:
c_init_code_apply.append(op.c_init_code_apply(node, name)) c_init_code_apply.append(op.c_init_code_apply(node, name))
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -1036,14 +1041,15 @@ class CLinker(link.Linker): ...@@ -1036,14 +1041,15 @@ class CLinker(link.Linker):
no_recycle list. no_recycle list.
""" """
return self.cmodule_key_(self.fgraph, self.no_recycling, return self.cmodule_key_(self.fgraph, self.no_recycling,
compile_args=self.compile_args(), compile_args=self.compile_args(),
libraries=self.libraries(), libraries=self.libraries(),
header_dirs=self.header_dirs(), header_dirs=self.header_dirs(),
c_compiler=self.c_compiler(), c_compiler=self.c_compiler(),
) )
def cmodule_key_(self, fgraph, no_recycling, compile_args=None, libraries=None, def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
header_dirs=None, insert_config_md5=True, c_compiler=None): libraries=None, header_dirs=None, insert_config_md5=True,
c_compiler=None):
""" """
Do the actual computation of cmodule_key in a static method Do the actual computation of cmodule_key in a static method
to allow it to be reused in scalar.Composite.__eq__ to allow it to be reused in scalar.Composite.__eq__
...@@ -1059,7 +1065,7 @@ class CLinker(link.Linker): ...@@ -1059,7 +1065,7 @@ class CLinker(link.Linker):
# seen 'so far' in the loop below # seen 'so far' in the loop below
fgraph_computed_set = set() fgraph_computed_set = set()
fgraph_inputs_dict = dict((i, (-1, pos)) for pos, i in fgraph_inputs_dict = dict((i, (-1, pos)) for pos, i in
enumerate(fgraph.inputs)) enumerate(fgraph.inputs))
constant_ids = dict() constant_ids = dict()
op_pos = {} # Apply -> topological position op_pos = {} # Apply -> topological position
...@@ -1167,7 +1173,7 @@ class CLinker(link.Linker): ...@@ -1167,7 +1173,7 @@ class CLinker(link.Linker):
sig.append(( sig.append((
node.op, node.op,
tuple((i.type, in_sig(i, node_pos, ipos)) tuple((i.type, in_sig(i, node_pos, ipos))
for ipos, i in enumerate(node.inputs)), for ipos, i in enumerate(node.inputs)),
(1, # Increment if cmodule change its handling of outputs (1, # Increment if cmodule change its handling of outputs
tuple(o in no_recycling for o in node.outputs)))) tuple(o in no_recycling for o in node.outputs))))
...@@ -1180,6 +1186,11 @@ class CLinker(link.Linker): ...@@ -1180,6 +1186,11 @@ class CLinker(link.Linker):
op_pos[node] = node_pos op_pos[node] = node_pos
fgraph_computed_set.update(node.outputs) fgraph_computed_set.update(node.outputs)
# Add not used input in the key
for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs)
if not len(var.clients)]:
sig.append((var.type, in_sig(var, -1, ipos)))
#crystalize the signature and version #crystalize the signature and version
sig = tuple(sig) sig = tuple(sig)
version = tuple(version) version = tuple(version)
...@@ -1458,10 +1469,10 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1458,10 +1469,10 @@ class OpWiseCLinker(link.LocalLinker):
__cache__ = {} __cache__ = {}
def __init__(self, def __init__(self,
fallback_on_perform=True, fallback_on_perform=True,
allow_gc=None, allow_gc=None,
nice_errors=True, nice_errors=True,
schedule=None): schedule=None):
if allow_gc is None: if allow_gc is None:
allow_gc = config.allow_gc allow_gc = config.allow_gc
self.fgraph = None self.fgraph = None
...@@ -1476,10 +1487,10 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1476,10 +1487,10 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)( return type(self)(
fallback_on_perform=self.fallback_on_perform, fallback_on_perform=self.fallback_on_perform,
allow_gc=self.allow_gc, allow_gc=self.allow_gc,
nice_errors=self.nice_errors nice_errors=self.nice_errors
).accept(fgraph, no_recycling) ).accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is #raise Exception("Cannot accept from a Linker that is
#already tied to another FunctionGraph.") #already tied to another FunctionGraph.")
self.fgraph = fgraph self.fgraph = fgraph
...@@ -1500,7 +1511,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1500,7 +1511,7 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage) fgraph, order, input_storage, output_storage)
if self.allow_gc: if self.allow_gc:
computed, last_user = link.gc_helper(order) computed, last_user = link.gc_helper(order)
post_thunk_old_storage = [] post_thunk_old_storage = []
...@@ -1523,9 +1534,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1523,9 +1534,9 @@ class OpWiseCLinker(link.LocalLinker):
if theano.config.cxx: if theano.config.cxx:
node.op._op_use_c_code = True node.op._op_use_c_code = True
thunks += [node.op.make_thunk(node, thunks += [node.op.make_thunk(node,
storage_map, storage_map,
compute_map, compute_map,
no_recycling)] no_recycling)]
thunks[-1].inputs = [storage_map[v] for v in node.inputs] thunks[-1].inputs = [storage_map[v] for v in node.inputs]
thunks[-1].outputs = [storage_map[v] for v in node.outputs] thunks[-1].outputs = [storage_map[v] for v in node.outputs]
...@@ -1548,9 +1559,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1548,9 +1559,9 @@ class OpWiseCLinker(link.LocalLinker):
for r in no_recycling if r not in fgraph.inputs] for r in no_recycling if r not in fgraph.inputs]
f = link.streamline(fgraph, thunks, order, f = link.streamline(fgraph, thunks, order,
post_thunk_old_storage, post_thunk_old_storage,
no_recycling=no_recycling, no_recycling=no_recycling,
nice_errors=self.nice_errors) nice_errors=self.nice_errors)
f.allow_gc = self.allow_gc f.allow_gc = self.allow_gc
......
...@@ -3,8 +3,9 @@ import unittest ...@@ -3,8 +3,9 @@ import unittest
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import theano
from theano.gof.link import PerformLinker from theano.gof.link import PerformLinker
from theano.gof.cc import * from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
...@@ -227,6 +228,18 @@ def test_clinker_dups(): ...@@ -227,6 +228,18 @@ def test_clinker_dups():
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_not_used_inputs():
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
e = add(x, y)
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
assert fn(2.0, 1.5, 1.0) == 3.5
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_dups_inner(): def test_clinker_dups_inner():
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
...@@ -253,6 +266,7 @@ def test_opwiseclinker_straightforward(): ...@@ -253,6 +266,7 @@ def test_opwiseclinker_straightforward():
# The python version of bad_sub always return -10. # The python version of bad_sub always return -10.
assert fn(2.0, 2.0, 2.0) == -6 assert fn(2.0, 2.0, 2.0) == -6
def test_opwiseclinker_constant(): def test_opwiseclinker_constant():
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name='x') x = Constant(tdouble, 7.2, name='x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论