提交 a09fbf2a authored 作者: Ian Goodfellow's avatar Ian Goodfellow

Merge remote-tracking branch 'origin/unimp_undef_grad' into unimp_undef_grad

Conflicts: theano/gradient.py
...@@ -15,7 +15,6 @@ import numpy ...@@ -15,7 +15,6 @@ import numpy
import theano import theano
from theano import gof from theano import gof
from theano.gof.python25 import partial from theano.gof.python25 import partial
from theano.gradient import check_for_bad_grad
import mode as mode_module import mode as mode_module
from io import In, SymbolicInput, SymbolicInputKit, SymbolicOutput from io import In, SymbolicInput, SymbolicInputKit, SymbolicOutput
...@@ -144,9 +143,31 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -144,9 +143,31 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
fgraph.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and fgraph.destroyers(input))))) fgraph.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and fgraph.destroyers(input)))))
# If named nodes are replaced, keep the name # If named nodes are replaced, keep the name
fgraph.extend(gof.toolbox.PreserveNames()) for feature in std_fgraph.features:
fgraph.extend(feature())
return fgraph, map(SymbolicOutput, updates) return fgraph, map(SymbolicOutput, updates)
std_fgraph.features = [gof.toolbox.PreserveNames]
class UncomputableFeature(gof.Feature):
    """A feature that ensures the graph never contains any
    uncomputable nodes.

    This check must be made at compile time rather than runtime in
    order to make sure that NaN nodes are not optimized out. It must
    be done as a Feature so that the fgraph will continually check
    that optimizations have not introduced any uncomputable nodes.
    """

    def on_attach(self, fgraph):
        # Check every node already present when the feature is attached.
        # BUG FIX: the original had `return self.on_import(...)` inside
        # the loop, so only the first node was ever checked.
        for node in fgraph.nodes:
            self.on_import(fgraph, node)

    def on_import(self, fgraph, node):
        # Called whenever a node is imported into fgraph; raises
        # immediately if the node's op cannot be computed.
        gof.op.raise_if_uncomputable(node)


# Register so every fgraph built by std_fgraph carries this check.
std_fgraph.features.append(UncomputableFeature)
class AliasedMemoryError(Exception): class AliasedMemoryError(Exception):
"""Memory is aliased that should not be""" """Memory is aliased that should not be"""
pass pass
...@@ -1337,8 +1358,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1337,8 +1358,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
t1 = time.time() t1 = time.time()
mode = mode_module.get_mode(mode) mode = mode_module.get_mode(mode)
check_for_bad_grad(outputs)
inputs = map(convert_function_input, inputs) inputs = map(convert_function_input, inputs)
if outputs is not None: if outputs is not None:
if isinstance(outputs, (list, tuple)): if isinstance(outputs, (list, tuple)):
......
...@@ -70,6 +70,7 @@ from optdb import \ ...@@ -70,6 +70,7 @@ from optdb import \
EquilibriumDB, SequenceDB, ProxyDB EquilibriumDB, SequenceDB, ProxyDB
from toolbox import \ from toolbox import \
Feature, \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\ Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\
PrintListener, ReplacementDidntRemovedError PrintListener, ReplacementDidntRemovedError
......
...@@ -19,6 +19,7 @@ class InconsistencyError(Exception): ...@@ -19,6 +19,7 @@ class InconsistencyError(Exception):
""" """
pass pass
class MissingInputError(Exception): class MissingInputError(Exception):
""" """
A symbolic input needed to compute the outputs is missing. A symbolic input needed to compute the outputs is missing.
...@@ -26,7 +27,6 @@ class MissingInputError(Exception): ...@@ -26,7 +27,6 @@ class MissingInputError(Exception):
pass pass
class FunctionGraph(utils.object2): class FunctionGraph(utils.object2):
""" WRITEME """ WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and a A FunctionGraph represents a subgraph bound by a set of input variables and a
...@@ -46,46 +46,8 @@ class FunctionGraph(utils.object2): ...@@ -46,46 +46,8 @@ class FunctionGraph(utils.object2):
The .clients field combined with the .owner field and the Apply nodes' The .clients field combined with the .owner field and the Apply nodes'
.inputs field allows the graph to be traversed in both directions. .inputs field allows the graph to be traversed in both directions.
It can also be "extended" using function_graph.extend(some_object). See the It can also be "extended" using function_graph.extend(some_object).
toolbox and ext modules for common extensions. See toolbox.Feature for event types and documentation.
Features added with the`extend` function can handle the following events:
- feature.on_attach(function_graph)
Called by extend. The feature has great freedom in what
it can do with the function_graph: it may, for example, add methods
to it dynamically.
- feature.on_detach(function_graph)
Called by remove_feature(feature). Should remove any dynamically-added
functionality that it installed into the function_graph.
- feature.on_import(function_graph, node)*
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
- feature.on_prune(function_graph, node)*
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
- feature.on_change_input(function_graph, node, i, r, new_r, [reason=None])*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(function_graph)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
- feature.on_setup_node(function_graph, node):
WRITEME
- feature.on_setup_variable(function_graph, variable):
WRITEME
Historically, the FunctionGraph was called an Env. Keep this in mind Historically, the FunctionGraph was called an Env. Keep this in mind
while reading out-of-date documentation, e-mail support threads, etc. while reading out-of-date documentation, e-mail support threads, etc.
......
...@@ -606,6 +606,63 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -606,6 +606,63 @@ class Op(utils.object2, PureOp, CLinkerOp):
rval.lazy = False rval.lazy = False
return rval return rval
class UncomputableOp(Op):
    """
    An Op representing an expression that cannot be computed.

    theano.function checks that the subgraph it implements does not
    contain these ops, and that optimization does not introduce any
    such ops. theano.tensor.grad checks the graphs it returns to
    ensure they do not contain these ops.
    """

    def __init__(self, exc, msg=""):
        """
        exc: the exception type to raise if a subgraph contains
            this op.
        msg: the message to include in the exception.
        """
        self.exc = exc
        self.msg = msg

    def __eq__(self, other):
        # All instances of the same concrete class compare equal;
        # `exc` and `msg` deliberately do not participate.
        return type(self) == type(other)

    def __hash__(self):
        # Must be consistent with __eq__ above.
        return hash(type(self))

    def __str__(self):
        return "Uncomputable{%s,%s}" % (self.exc, self.msg)

    def make_node(self, x):
        # Output has the same type as the input, like an identity op.
        return graph.Apply(self, [x], [x.type()])

    def perform(self, node, inputs, out_storage):
        """This should never be called."""
        # BUG FIX: the message referred to "BadGradOp", the class's
        # name before it was renamed to UncomputableOp.
        raise AssertionError("An UncomputableOp should never be compiled, "
                             "and certainly not executed.")
    # Note: essentially, this op should just be NaNs_like(inputs[0])
    # but 0 * UncomputableOp(x) + y optimizes to just y,
    # so until we develop a way of symbolically representing a variable
    # that is always NaN and implement the logic for 0 * NaN = NaN, etc.
    # the only way we can guarantee correctness of a theano function
    # is to guarantee that its initial subgraph contained no
    # UncomputableOps.

    def raise_exc(self):
        """Raise the exception this op was constructed with."""
        raise self.exc(self.msg)
def raise_if_uncomputable(node):
    """Raise the stored exception if `node` applies an UncomputableOp.

    node: an Apply node, or None (e.g. the owner of a graph input).
        Does nothing when node is None or its op is computable.

    Raises: whatever exception type the UncomputableOp was built with.
    """
    # The original left debug `print` statements in here, emitting
    # output on every node visited; they have been removed.
    if node is not None and isinstance(node.op, UncomputableOp):
        node.op.raise_exc()
def get_test_value(v): def get_test_value(v):
""" """
......
...@@ -19,7 +19,61 @@ class ReplacementDidntRemovedError(Exception): ...@@ -19,7 +19,61 @@ class ReplacementDidntRemovedError(Exception):
pass pass
class Feature(object):
    """
    Base class for FunctionGraph extensions.

    Subclasses override the callbacks they care about; every default
    implementation here is a no-op, so a Feature may implement any
    subset of the protocol. See toolbox and ext modules for common
    extensions.
    """

    def on_attach(self, function_graph):
        """
        Called by extend. The feature has great freedom in what
        it can do with the function_graph: it may, for example, add
        methods to it dynamically.
        """

    def on_detach(self, function_graph):
        """
        Called by remove_feature(feature). Should remove any
        dynamically-added functionality that it installed into the
        function_graph.
        """

    def on_import(self, function_graph, node):
        """
        Called whenever a node is imported into function_graph, which
        is just before the node is actually connected to the graph.
        """

    def on_prune(self, function_graph, node):
        """
        Called whenever a node is pruned (removed) from the
        function_graph, after it is disconnected from the graph.
        """

    def on_change_input(self, function_graph, node, i, r, new_r, reason=None):
        """
        Called whenever node.inputs[i] is changed from r to new_r.
        At the moment the callback is done, the change has already
        taken place.

        If you raise an exception in this function, the state of the
        graph might be broken for all intents and purposes.
        """

    def orderings(self, function_graph):
        """
        Called by toposort. It should return a dictionary of
        {node: predecessors} where predecessors is a list of
        nodes that should be computed before the key node.

        If you raise an exception in this function, the state of the
        graph might be broken for all intents and purposes.
        """
        # No extra ordering constraints by default.
        return {}
class Bookkeeper(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
...@@ -30,7 +84,7 @@ class Bookkeeper: ...@@ -30,7 +84,7 @@ class Bookkeeper:
self.on_prune(fgraph, node) self.on_prune(fgraph, node)
class History: class History(Feature):
def __init__(self): def __init__(self):
self.history = {} self.history = {}
...@@ -69,7 +123,7 @@ class History: ...@@ -69,7 +123,7 @@ class History:
self.history[fgraph] = h self.history[fgraph] = h
class Validator: class Validator(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for attr in ('validate', 'validate_time'): for attr in ('validate', 'validate_time'):
...@@ -224,7 +278,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -224,7 +278,7 @@ class NodeFinder(dict, Bookkeeper):
return all return all
class PrintListener(object): class PrintListener(Feature):
def __init__(self, active=True): def __init__(self, active=True):
self.active = active self.active = active
...@@ -251,7 +305,9 @@ class PrintListener(object): ...@@ -251,7 +305,9 @@ class PrintListener(object):
node, i, r, new_r) node, i, r, new_r)
class PreserveNames: class PreserveNames(Feature):
def on_change_input(self, fgraph, mode, i, r, new_r, reason=None): def on_change_input(self, fgraph, mode, i, r, new_r, reason=None):
if r.name is not None and new_r.name is None: if r.name is not None and new_r.name is None:
new_r.name = r.name new_r.name = r.name
...@@ -194,52 +194,8 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -194,52 +194,8 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
gmap[r] = g_r gmap[r] = g_r
return gmap return gmap
class BadGradOp(gof.Op):
"""
An Op representing a gradient that cannot be computed.
theano.tensor.grad checks the graphs it returns to ensure
they do not contain these ops.
theano.function also checks that the subgraph it implements
does not contain these ops.
"""
def __init__(self, exc, msg=""):
"""
exc: the exception type to raise if a subgraph contains
this op.
msg: the message to include in the exception.
"""
self.exc = exc class GradNotImplementedOp(gof.op.UncomputableOp):
self.msg = msg
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash((type(self)))
def __str__(self):
return "BadGrad{%s,%s}"%(self.exc,self.msg)
def make_node(self,x):
return gof.Apply(self, [x], [x.type()] )
def perform(self, node, inputs, out_storage):
""" This should never be called"""
raise AssertionError("A BadGradOp should never be compiled, "+\
"and certainly not executed.")
#Note: essentially, this op should just be NaNs_like(inputs[0])
#but 0 * BadGradOp(x) + y optimizes to just y
#so until we develop a way of symbolically representing a variable
#that is always NaN and implement the logic for 0 * NaN = NaN, etc.
#the only way we can guarantee correctness of a theano function
#is to guarantee that its initial subgraph contained no BadGradOps
def raise_exc(self):
raise self.exc(self.msg)
class GradNotImplementedOp(BadGradOp):
""" A BadGradOp representing a gradient that hasn't been implemented yet. """ A BadGradOp representing a gradient that hasn't been implemented yet.
""" """
...@@ -261,6 +217,7 @@ class GradNotImplementedOp(BadGradOp): ...@@ -261,6 +217,7 @@ class GradNotImplementedOp(BadGradOp):
"%s does not implement its gradient with respect to input %d" \ "%s does not implement its gradient with respect to input %d" \
% (str(type(op)), x_pos)) % (str(type(op)), x_pos))
def grad_not_implemented(op, x_pos, x): def grad_not_implemented(op, x_pos, x):
""" """
Return an un-computable symbolic variable of type `x.type`. Return an un-computable symbolic variable of type `x.type`.
...@@ -274,79 +231,6 @@ def grad_not_implemented(op, x_pos, x): ...@@ -274,79 +231,6 @@ def grad_not_implemented(op, x_pos, x):
return GradNotImplementedOp(op, x_pos)(x) return GradNotImplementedOp(op, x_pos)(x)
def check_for_bad_grad(variables):
    """
    variables: a gof.Variable, an object wrapping one in a
        ``.variable`` attribute (e.g. theano.compile.io.SymbolicOutput),
        or a list of either.

    Raises the exception stored in the offending op if any of the
    variables represents an expression involving a BadGradOp.
    """
    # Normalize to a list without mutating the caller's list in place
    # (the original rewrote `variables[i]` on the argument itself).
    if not isinstance(variables, list):
        variables = [variables]
    unwrapped = []
    for v in variables:
        # Unwrap SymbolicOutput-like objects that carry the real
        # variable in a `.variable` attribute.
        if not isinstance(v, gof.Variable) and \
                hasattr(v, 'variable') and \
                isinstance(v.variable, gof.Variable):
            v = v.variable
        unwrapped.append(v)
    # gof.graph.ancestors traverses iteratively, so deep graphs do not
    # risk hitting Python's recursion limit.
    for v in gof.graph.ancestors(unwrapped):
        if v.owner is not None and isinstance(v.owner.op, BadGradOp):
            v.owner.op.raise_exc()
    # The original kept a 35-line superseded deque-based implementation
    # inside a dangling triple-quoted string here; dead code removed.
######################## ########################
# R Operator # R Operator
...@@ -667,7 +551,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -667,7 +551,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
and ret[-1].name is None: and ret[-1].name is None:
ret[-1].name = '(d%s/d%s)' % (cost.name, p.name) ret[-1].name = '(d%s/d%s)' % (cost.name, p.name)
check_for_bad_grad(ret) # new_vars is meant to be a list of all variables created
# by this call to grad(), which will be visible to the caller
# after we return.
new_vars = gof.graph.ancestors(ret,
blockers=gof.graph.ancestors([cost]) + list(wrt))
map(gof.op.raise_if_uncomputable, [v.owner for v in new_vars])
return format_as(using_list, using_tuple, ret) return format_as(using_list, using_tuple, ret)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论