提交 c4e182ab authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1666 from nouiz/gh-1381

fix gh-1381 c linker crash with not used inputs.
......@@ -51,8 +51,8 @@ from theano import config
# of cutils_ext.
from theano.configparser import AddConfigVar, StrParam
AddConfigVar('gcc.cxxflags',
"Extra compiler flags for gcc",
StrParam(""))
"Extra compiler flags for gcc",
StrParam(""))
# gof imports
from theano.gof import graph
......@@ -432,7 +432,7 @@ class CLinker(link.Linker):
def accept(self, fgraph, no_recycling=None):
"""WRITEME"""
if no_recycling is None:
no_recycling = []
no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)().accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is already"
......@@ -450,8 +450,13 @@ class CLinker(link.Linker):
fgraph = self.fgraph
self.inputs = fgraph.inputs
self.outputs = fgraph.outputs
# list(fgraph.variables)
self.variables = graph.variables(self.inputs, self.outputs)
# We need to include the not used inputs in our variables,
# otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(var.clients)]
self.variables += graph.variables(self.inputs, self.outputs)
# The orphans field is listified to ensure a consistent order.
#list(fgraph.orphans.difference(self.outputs))
self.orphans = list(r for r in self.variables
......@@ -528,7 +533,7 @@ class CLinker(link.Linker):
if isinstance(variable, graph.Constant):
try:
symbol[variable] = ("(" + variable.type.c_literal(
variable.data) + ")")
variable.data) + ")")
self.consts.append(variable)
self.orphans.remove(variable)
continue
......@@ -630,8 +635,8 @@ class CLinker(link.Linker):
else:
# The following will be executed if the "try" block succeeds
assert isinstance(c_support_code_apply[-1], basestring), (
str(node.op) +
" didn't return a string for c_support_code_apply")
str(node.op) +
" didn't return a string for c_support_code_apply")
try:
c_init_code_apply.append(op.c_init_code_apply(node, name))
except utils.MethodNotDefined:
......@@ -1036,14 +1041,15 @@ class CLinker(link.Linker):
no_recycle list.
"""
return self.cmodule_key_(self.fgraph, self.no_recycling,
compile_args=self.compile_args(),
libraries=self.libraries(),
header_dirs=self.header_dirs(),
c_compiler=self.c_compiler(),
)
def cmodule_key_(self, fgraph, no_recycling, compile_args=None, libraries=None,
header_dirs=None, insert_config_md5=True, c_compiler=None):
compile_args=self.compile_args(),
libraries=self.libraries(),
header_dirs=self.header_dirs(),
c_compiler=self.c_compiler(),
)
def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
libraries=None, header_dirs=None, insert_config_md5=True,
c_compiler=None):
"""
Do the actual computation of cmodule_key in a static method
to allow it to be reused in scalar.Composite.__eq__
......@@ -1059,7 +1065,7 @@ class CLinker(link.Linker):
# seen 'so far' in the loop below
fgraph_computed_set = set()
fgraph_inputs_dict = dict((i, (-1, pos)) for pos, i in
enumerate(fgraph.inputs))
enumerate(fgraph.inputs))
constant_ids = dict()
op_pos = {} # Apply -> topological position
......@@ -1167,7 +1173,7 @@ class CLinker(link.Linker):
sig.append((
node.op,
tuple((i.type, in_sig(i, node_pos, ipos))
for ipos, i in enumerate(node.inputs)),
for ipos, i in enumerate(node.inputs)),
(1, # Increment if cmodule change its handling of outputs
tuple(o in no_recycling for o in node.outputs))))
......@@ -1180,6 +1186,11 @@ class CLinker(link.Linker):
op_pos[node] = node_pos
fgraph_computed_set.update(node.outputs)
# Add not used input in the key
for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs)
if not len(var.clients)]:
sig.append((var.type, in_sig(var, -1, ipos)))
#crystalize the signature and version
sig = tuple(sig)
version = tuple(version)
......@@ -1458,10 +1469,10 @@ class OpWiseCLinker(link.LocalLinker):
__cache__ = {}
def __init__(self,
fallback_on_perform=True,
allow_gc=None,
nice_errors=True,
schedule=None):
fallback_on_perform=True,
allow_gc=None,
nice_errors=True,
schedule=None):
if allow_gc is None:
allow_gc = config.allow_gc
self.fgraph = None
......@@ -1476,10 +1487,10 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)(
fallback_on_perform=self.fallback_on_perform,
allow_gc=self.allow_gc,
nice_errors=self.nice_errors
).accept(fgraph, no_recycling)
fallback_on_perform=self.fallback_on_perform,
allow_gc=self.allow_gc,
nice_errors=self.nice_errors
).accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is
#already tied to another FunctionGraph.")
self.fgraph = fgraph
......@@ -1500,7 +1511,7 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage)
fgraph, order, input_storage, output_storage)
if self.allow_gc:
computed, last_user = link.gc_helper(order)
post_thunk_old_storage = []
......@@ -1523,9 +1534,9 @@ class OpWiseCLinker(link.LocalLinker):
if theano.config.cxx:
node.op._op_use_c_code = True
thunks += [node.op.make_thunk(node,
storage_map,
compute_map,
no_recycling)]
storage_map,
compute_map,
no_recycling)]
thunks[-1].inputs = [storage_map[v] for v in node.inputs]
thunks[-1].outputs = [storage_map[v] for v in node.outputs]
......@@ -1548,9 +1559,9 @@ class OpWiseCLinker(link.LocalLinker):
for r in no_recycling if r not in fgraph.inputs]
f = link.streamline(fgraph, thunks, order,
post_thunk_old_storage,
no_recycling=no_recycling,
nice_errors=self.nice_errors)
post_thunk_old_storage,
no_recycling=no_recycling,
nice_errors=self.nice_errors)
f.allow_gc = self.allow_gc
......
......@@ -3,8 +3,9 @@ import unittest
from nose.plugins.skip import SkipTest
import theano
from theano.gof.link import PerformLinker
from theano.gof.cc import *
from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker
from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op
......@@ -227,6 +228,18 @@ def test_clinker_dups():
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_not_used_inputs():
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
e = add(x, y)
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
assert fn(2.0, 1.5, 1.0) == 3.5
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_dups_inner():
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
......@@ -253,6 +266,7 @@ def test_opwiseclinker_straightforward():
# The python version of bad_sub always return -10.
assert fn(2.0, 2.0, 2.0) == -6
def test_opwiseclinker_constant():
x, y, z = inputs()
x = Constant(tdouble, 7.2, name='x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论