提交 848b62bd authored 作者: goodfeli's avatar goodfeli

Merge pull request #3 from jaberg/unimp_undef_grad

Using a Feature to catch BadGradOp
...@@ -15,7 +15,6 @@ import numpy ...@@ -15,7 +15,6 @@ import numpy
import theano import theano
from theano import gof from theano import gof
from theano.gof.python25 import partial from theano.gof.python25 import partial
from theano.gradient import check_for_bad_grad
import mode as mode_module import mode as mode_module
from io import In, SymbolicInput, SymbolicInputKit, SymbolicOutput from io import In, SymbolicInput, SymbolicInputKit, SymbolicOutput
...@@ -144,9 +143,14 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -144,9 +143,14 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
fgraph.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and fgraph.destroyers(input))))) fgraph.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and fgraph.destroyers(input)))))
# If named nodes are replaced, keep the name # If named nodes are replaced, keep the name
fgraph.extend(gof.toolbox.PreserveNames()) for feature in std_fgraph.features:
fgraph.extend(feature())
return fgraph, map(SymbolicOutput, updates) return fgraph, map(SymbolicOutput, updates)
std_fgraph.features = [gof.toolbox.PreserveNames]
class AliasedMemoryError(Exception):
    """Raised when memory is aliased that should not be."""
...@@ -1337,8 +1341,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1337,8 +1341,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
t1 = time.time() t1 = time.time()
mode = mode_module.get_mode(mode) mode = mode_module.get_mode(mode)
check_for_bad_grad(outputs)
inputs = map(convert_function_input, inputs) inputs = map(convert_function_input, inputs)
if outputs is not None: if outputs is not None:
if isinstance(outputs, (list, tuple)): if isinstance(outputs, (list, tuple)):
......
...@@ -70,6 +70,7 @@ from optdb import \ ...@@ -70,6 +70,7 @@ from optdb import \
EquilibriumDB, SequenceDB, ProxyDB EquilibriumDB, SequenceDB, ProxyDB
from toolbox import \ from toolbox import \
Feature, \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\ Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\
PrintListener, ReplacementDidntRemovedError PrintListener, ReplacementDidntRemovedError
......
...@@ -19,6 +19,7 @@ class InconsistencyError(Exception): ...@@ -19,6 +19,7 @@ class InconsistencyError(Exception):
""" """
pass pass
class MissingInputError(Exception): class MissingInputError(Exception):
""" """
A symbolic input needed to compute the outputs is missing. A symbolic input needed to compute the outputs is missing.
...@@ -26,7 +27,6 @@ class MissingInputError(Exception): ...@@ -26,7 +27,6 @@ class MissingInputError(Exception):
pass pass
class FunctionGraph(utils.object2): class FunctionGraph(utils.object2):
""" WRITEME """ WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and a A FunctionGraph represents a subgraph bound by a set of input variables and a
...@@ -46,46 +46,8 @@ class FunctionGraph(utils.object2): ...@@ -46,46 +46,8 @@ class FunctionGraph(utils.object2):
The .clients field combined with the .owner field and the Apply nodes' The .clients field combined with the .owner field and the Apply nodes'
.inputs field allows the graph to be traversed in both directions. .inputs field allows the graph to be traversed in both directions.
It can also be "extended" using function_graph.extend(some_object). See the It can also be "extended" using function_graph.extend(some_object).
toolbox and ext modules for common extensions. See toolbox.Feature for event types and documentation.
Features added with the`extend` function can handle the following events:
- feature.on_attach(function_graph)
Called by extend. The feature has great freedom in what
it can do with the function_graph: it may, for example, add methods
to it dynamically.
- feature.on_detach(function_graph)
Called by remove_feature(feature). Should remove any dynamically-added
functionality that it installed into the function_graph.
- feature.on_import(function_graph, node)*
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
- feature.on_prune(function_graph, node)*
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
- feature.on_change_input(function_graph, node, i, r, new_r, [reason=None])*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(function_graph)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
- feature.on_setup_node(function_graph, node):
WRITEME
- feature.on_setup_variable(function_graph, variable):
WRITEME
Historically, the FunctionGraph was called an Env. Keep this in mind Historically, the FunctionGraph was called an Env. Keep this in mind
while reading out-of-date documentation, e-mail support threads, etc. while reading out-of-date documentation, e-mail support threads, etc.
......
...@@ -19,7 +19,61 @@ class ReplacementDidntRemovedError(Exception): ...@@ -19,7 +19,61 @@ class ReplacementDidntRemovedError(Exception):
pass pass
class Feature(object):
    """
    Base class for FunctionGraph extensions.

    A Feature is attached to a FunctionGraph via ``extend`` and is
    notified of graph events through the callbacks below.  Every
    callback is a no-op here, so subclasses override only the events
    they care about.  See the toolbox and ext modules for common
    extensions.
    """

    def on_attach(self, function_graph):
        """
        Called once by ``extend``.  The feature may do anything it
        likes with `function_graph` -- including dynamically adding
        methods to it.
        """

    def on_detach(self, function_graph):
        """
        Called by ``remove_feature(feature)``.  Undo whatever
        ``on_attach`` installed on `function_graph` (e.g. remove any
        dynamically-added functionality).
        """

    def on_import(self, function_graph, node):
        """
        Called for every node imported into `function_graph`, just
        before that node is actually connected to the graph.
        """

    def on_prune(self, function_graph, node):
        """
        Called for every node pruned (removed) from `function_graph`,
        after it has been disconnected from the graph.
        """

    def on_change_input(self, function_graph, node, i, r, new_r, reason=None):
        """
        Called after ``node.inputs[i]`` has been changed from `r` to
        `new_r` -- by the time this runs, the change is already in
        place.

        Raising an exception from this callback may leave the graph in
        a broken state for all intents and purposes.
        """

    def orderings(self, function_graph):
        """
        Called by toposort.  Return a dictionary mapping
        ``{node: predecessors}``, where ``predecessors`` is a list of
        nodes that must be computed before the key node.

        Raising an exception from this callback may leave the graph in
        a broken state for all intents and purposes.
        """
        return {}
class Bookkeeper(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
...@@ -30,7 +84,7 @@ class Bookkeeper: ...@@ -30,7 +84,7 @@ class Bookkeeper:
self.on_prune(fgraph, node) self.on_prune(fgraph, node)
class History: class History(Feature):
def __init__(self): def __init__(self):
self.history = {} self.history = {}
...@@ -69,7 +123,7 @@ class History: ...@@ -69,7 +123,7 @@ class History:
self.history[fgraph] = h self.history[fgraph] = h
class Validator: class Validator(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for attr in ('validate', 'validate_time'): for attr in ('validate', 'validate_time'):
...@@ -224,7 +278,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -224,7 +278,7 @@ class NodeFinder(dict, Bookkeeper):
return all return all
class PrintListener(object): class PrintListener(Feature):
def __init__(self, active=True): def __init__(self, active=True):
self.active = active self.active = active
...@@ -251,7 +305,9 @@ class PrintListener(object): ...@@ -251,7 +305,9 @@ class PrintListener(object):
node, i, r, new_r) node, i, r, new_r)
class PreserveNames(Feature):
    """
    Carry a variable's name over to its replacement.

    When ``node.inputs[i]`` is swapped from `r` to `new_r` and only the
    outgoing variable was named, copy that name onto the incoming one.
    """

    def on_change_input(self, fgraph, mode, i, r, new_r, reason=None):
        # Only fill in a missing name -- never overwrite an existing one.
        if new_r.name is None and r.name is not None:
            new_r.name = r.name
...@@ -194,6 +194,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -194,6 +194,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
gmap[r] = g_r gmap[r] = g_r
return gmap return gmap
class BadGradOp(gof.Op): class BadGradOp(gof.Op):
""" """
An Op representing a gradient that cannot be computed. An Op representing a gradient that cannot be computed.
...@@ -239,6 +240,7 @@ class BadGradOp(gof.Op): ...@@ -239,6 +240,7 @@ class BadGradOp(gof.Op):
def raise_exc(self): def raise_exc(self):
raise self.exc(self.msg) raise self.exc(self.msg)
class GradNotImplementedOp(BadGradOp): class GradNotImplementedOp(BadGradOp):
""" A BadGradOp representing a gradient that hasn't been implemented yet. """ A BadGradOp representing a gradient that hasn't been implemented yet.
""" """
...@@ -261,6 +263,7 @@ class GradNotImplementedOp(BadGradOp): ...@@ -261,6 +263,7 @@ class GradNotImplementedOp(BadGradOp):
"%s does not implement its gradient with respect to input %d" \ "%s does not implement its gradient with respect to input %d" \
% (str(type(op)), x_pos)) % (str(type(op)), x_pos))
def grad_not_implemented(op, x_pos, x): def grad_not_implemented(op, x_pos, x):
""" """
Return an un-computable symbolic variable of type `x.type`. Return an un-computable symbolic variable of type `x.type`.
...@@ -274,59 +277,18 @@ def grad_not_implemented(op, x_pos, x): ...@@ -274,59 +277,18 @@ def grad_not_implemented(op, x_pos, x):
return GradNotImplementedOp(op, x_pos)(x) return GradNotImplementedOp(op, x_pos)(x)
def check_for_bad_grad( variables ):
"""
variables: A gof.Variable or list thereof
Raises an exception if any of the variables represents
an expression involving a BadGradOp
"""
#implemented using a deque rather than recursion because python recursion
#limit is set low by default
#handle the case where var is a theano.compile.io.SymbolicOutput
if hasattr(variables,'variable'):
variables = [ variables.variable ]
if not (isinstance(variables, list) or \
isinstance(variables, gof.Variable)):
raise TypeError("Expected gof.Variable or list thereof, got "+\
str(type(variables)))
if not isinstance(variables,list):
variables = [ variables ]
def raise_if_bad_grad(node):
    """
    Re-raise the exception recorded by a BadGradOp, if `node` carries one.

    :param node: an Apply node or None (typically ``some_variable.owner``).

    Does nothing when `node` is None or when ``node.op`` is not a
    BadGradOp; otherwise calls ``node.op.raise_exc()`` to raise the
    stored exception.
    """
    # Bug fix: the original body called ``op.raise_exc()``, but ``op``
    # was never bound in this scope, so hitting a BadGradOp raised a
    # NameError instead of the intended exception.  The op lives on the
    # node itself.
    if node is not None and isinstance(node.op, BadGradOp):
        node.op.raise_exc()
class BadGradFeature(gof.Feature):
    """
    FunctionGraph feature that rejects graphs containing a BadGradOp.

    Each time a node is imported into the graph, the recorded exception
    is re-raised if that node's op is a BadGradOp.
    """

    def on_import(self, fgraph, node):
        # Delegate the actual check to the module-level helper.
        raise_if_bad_grad(node)


# Register the feature so every FunctionGraph built by std_fgraph
# performs this check automatically.
theano.compile.function_module.std_fgraph.features.append(BadGradFeature)
######################## ########################
...@@ -648,7 +610,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -648,7 +610,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
and ret[-1].name is None: and ret[-1].name is None:
ret[-1].name = '(d%s/d%s)' % (cost.name, p.name) ret[-1].name = '(d%s/d%s)' % (cost.name, p.name)
check_for_bad_grad(ret) # new_vars is meant to be a list of all variables created
# by this call to grad(), which will be visible to the caller
# after we return.
new_vars = graph.ancestors(ret,
blockers=graph.ancestors(cost) + list(wrt))
map(raise_if_bad_grad, [v.owner for v in new_vars])
return format_as(using_list, using_tuple, ret) return format_as(using_list, using_tuple, ret)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论