提交 c75f7892 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3003 from nouiz/fg_crash

FunctionGraph crash
...@@ -112,7 +112,9 @@ class FunctionGraph(utils.object2): ...@@ -112,7 +112,9 @@ class FunctionGraph(utils.object2):
# outputs are cached in this field # outputs are cached in this field
self.apply_nodes = set() self.apply_nodes = set()
# Ditto for variable nodes # Ditto for variable nodes.
# It must contain all fgraph.inputs and all apply_nodes
# outputs even if they aren't used in the graph.
self.variables = set() self.variables = set()
self.inputs = list(inputs) self.inputs = list(inputs)
...@@ -131,7 +133,8 @@ class FunctionGraph(utils.object2): ...@@ -131,7 +133,8 @@ class FunctionGraph(utils.object2):
self.__setup_r__(input) self.__setup_r__(input)
self.variables.add(input) self.variables.add(input)
self.__import_r__(outputs, reason="init") for output in outputs:
self.__import_r__(output, reason="init")
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
output.clients.append(('output', i)) output.clients.append(('output', i))
...@@ -233,45 +236,39 @@ class FunctionGraph(utils.object2): ...@@ -233,45 +236,39 @@ class FunctionGraph(utils.object2):
""" """
for entry in clients_to_remove: for entry in clients_to_remove:
r.clients.remove(entry) r.clients.remove(entry)
if entry in r.clients:
print('ERROR: DUPLICATE CLIENT ENTRY...', file=sys.stderr)
print(' ENTRY', repr(entry), type(entry[0]), file=sys.stderr)
print(' CLIENTS', repr(r.clients), file=sys.stderr)
assert entry not in r.clients # an op,i pair should be unique assert entry not in r.clients # an op,i pair should be unique
if not r.clients: if not r.clients:
if prune: if prune:
self.__prune_r__([r], reason) self.__prune_r__(r, reason)
return False return False
return True return True
return False return False
### import ### ### import ###
def __import_r__(self, variables, reason): def __import_r__(self, variable, reason):
global NullType global NullType
if NullType is None: if NullType is None:
from null_type import NullType from null_type import NullType
# Imports the owners of the variables # Imports the owners of the variables
for apply_node in [r.owner for r in variables if r.owner is not None]: if variable.owner and variable.owner not in self.apply_nodes:
if apply_node not in self.apply_nodes: self.__import__(variable.owner, reason=reason)
self.__import__(apply_node, reason=reason) if (variable.owner is None and
for r in variables: not isinstance(variable, graph.Constant) and
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: variable not in self.inputs):
if isinstance(r.type, NullType): if isinstance(variable.type, NullType):
raise TypeError("Computation graph contains a NaN. " + raise TypeError("Computation graph contains a NaN. " +
r.type.why_null) variable.type.why_null)
raise MissingInputError("Undeclared input", r) raise MissingInputError("Undeclared input", variable)
if not getattr(r, 'fgraph', None) is self: if not getattr(variable, 'fgraph', None) is self:
self.__setup_r__(r) self.__setup_r__(variable)
self.variables.add(r) self.variables.add(variable)
def __import__(self, apply_node, check=True, reason=None): def __import__(self, apply_node, check=True, reason=None):
node = apply_node
# We import the nodes in topological order. We only are interested # We import the nodes in topological order. We only are interested
# in new nodes, so we use all variables we know of as if they were the input set. # in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to # (the functions in the graph module only use the input set to
# know where to stop going down) # know where to stop going down)
new_nodes = graph.io_toposort(self.variables, node.outputs) new_nodes = graph.io_toposort(self.variables, apply_node.outputs)
if check: if check:
for node in new_nodes: for node in new_nodes:
...@@ -376,34 +373,58 @@ class FunctionGraph(utils.object2): ...@@ -376,34 +373,58 @@ class FunctionGraph(utils.object2):
self.execute_callbacks('on_import', node, reason) self.execute_callbacks('on_import', node, reason)
### prune ### ### prune ###
def __prune_r__(self, variables, reason=None): def __prune_r__(self, variable, reason=None):
"""Should be called for variable that aren't used anymore:
len(var.clients) == 0
This does not mean we will remove it from fgraph.variables. If
the owner stays in the fgraph because other outputs are still used,
the variable will stay in fgraph.variables.
"""
# Prunes the owners of the variables. # Prunes the owners of the variables.
for node in set(r.owner for r in variables if r.owner is not None): if variable.owner:
self.__prune__(node, reason) self.__prune__(variable.owner, reason)
for r in variables: # variable should not have any clients.
if not r.clients and r in self.variables: # assert not variable.clients
self.variables.remove(r)
# variable should be in self.variables
# Why does this assert fail? Making it hold could speed up the optimizer.
# I think it fails because we remove the variable from self.variables
# in another place.
# assert variable in self.variables
if variable in self.variables:
# If the owner has other outputs that are still used,
# then we must keep that variable in the graph.
if not variable.owner or not any(
[var for var in variable.owner.outputs
if var.clients]):
self.variables.remove(variable)
def __prune__(self, apply_node, reason=None): def __prune__(self, apply_node, reason=None):
node = apply_node """Always called on owner of pruned variable from the graph.
if node not in self.apply_nodes:
raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node) This does not mean we will remove it from the graph. If other
assert node.fgraph is self outputs are still used, we will keep the node in the graph.
# If node's outputs have no clients, removes it from the graph
"""
# If apply_node's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one # and recursively tries to prune its inputs. If at least one
# of the op's outputs is an output to the graph or has a client # of the op's outputs is an output to the graph or has a client
# then __prune__ is a no-op. # then __prune__ is a no-op.
for output in node.outputs: for output in apply_node.outputs:
# Cannot prune an op which is an output or used somewhere # Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: # output in self.outputs or self.clients(output): if output.clients or output in self.outputs:
return return
self.apply_nodes.remove(node) self.apply_nodes.remove(apply_node)
self.variables.difference_update(node.outputs) self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', node, reason) self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(node.inputs): for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(node, i)], reason=reason) self.__remove_clients__(input, [(apply_node, i)], reason=reason)
# self.__prune_r__(node.inputs) # self.__prune_r__(apply_node.inputs)
### change input ### ### change input ###
def change_input(self, node, i, new_r, reason=None): def change_input(self, node, i, new_r, reason=None):
...@@ -438,7 +459,7 @@ class FunctionGraph(utils.object2): ...@@ -438,7 +459,7 @@ class FunctionGraph(utils.object2):
if r is new_r: if r is new_r:
return return
self.__import_r__([new_r], reason=reason) self.__import_r__(new_r, reason=reason)
self.__add_clients__(new_r, [(node, i)]) self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False) prune = self.__remove_clients__(r, [(node, i)], False)
# Precondition: the substitution is semantically valid # Precondition: the substitution is semantically valid
...@@ -448,7 +469,7 @@ class FunctionGraph(utils.object2): ...@@ -448,7 +469,7 @@ class FunctionGraph(utils.object2):
r, new_r, reason=reason) r, new_r, reason=reason)
if prune: if prune:
self.__prune_r__([r], reason=reason) self.__prune_r__(r, reason=reason)
### replace ### ### replace ###
def replace(self, r, new_r, reason=None, verbose=None): def replace(self, r, new_r, reason=None, verbose=None):
...@@ -549,8 +570,9 @@ class FunctionGraph(utils.object2): ...@@ -549,8 +570,9 @@ class FunctionGraph(utils.object2):
""" """
try: try:
# Why do we catch the exception anyway?
self._features.remove(feature) self._features.remove(feature)
except Exception: except ValueError:
return return
detach = getattr(feature, 'on_detach', None) detach = getattr(feature, 'on_detach', None)
if detach is not None: if detach is not None:
...@@ -654,10 +676,6 @@ class FunctionGraph(utils.object2): ...@@ -654,10 +676,6 @@ class FunctionGraph(utils.object2):
ords[node] = list(OrderedSet(prereqs)) ords[node] = list(OrderedSet(prereqs))
return ords return ords
def nclients(self, r):
"""WRITEME Same as len(self.clients(r))."""
return len(self.clients(r))
def check_integrity(self): def check_integrity(self):
"""WRITEME """WRITEME
Call this for a diagnosis if things go awry. Call this for a diagnosis if things go awry.
......
import os
import pickle import pickle
import sys
import unittest import unittest
from nose.plugins.skip import SkipTest
import theano import theano
from theano.compat.six import PY3
from theano.gof import CachedConstantError, FunctionGraph from theano.gof import CachedConstantError, FunctionGraph
from theano import tensor as tt from theano import tensor as tt
...@@ -24,3 +29,23 @@ class TFunctionGraph(unittest.TestCase): ...@@ -24,3 +29,23 @@ class TFunctionGraph(unittest.TestCase):
s = pickle.dumps(func) s = pickle.dumps(func)
func2 = pickle.loads(s) func2 = pickle.loads(s)
def test_node_outputs_not_used(self):
"""In the past, we where removing some not used variable from
fgraph.variables even if the apply node had other outputs used in
the graph. This caused a crash.
This test runs the pickle that reproduces this case.
"""
if sys.version_info[:2] < (2, 7):
raise SkipTest("This test need python 2.7 or more recent.")
with open(os.path.join(os.path.dirname(__file__),
'test_fg_old_crash.pkl'),
'rb') as f:
from theano.misc.pkl_utils import CompatUnpickler
if PY3:
u = CompatUnpickler(f, encoding="latin1")
else:
u = CompatUnpickler(f)
d = u.load()
f = theano.function(**d)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论