Commit 848b62bd authored by goodfeli

Merge pull request #3 from jaberg/unimp_undef_grad

Using a Feature to catch BadGradOp
......@@ -15,7 +15,6 @@ import numpy
import theano
from theano import gof
from theano.gof.python25 import partial
from theano.gradient import check_for_bad_grad
import mode as mode_module
from io import In, SymbolicInput, SymbolicInputKit, SymbolicOutput
......@@ -144,9 +143,14 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
fgraph.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and fgraph.destroyers(input)))))
# If named nodes are replaced, keep the name
fgraph.extend(gof.toolbox.PreserveNames())
for feature in std_fgraph.features:
fgraph.extend(feature())
return fgraph, map(SymbolicOutput, updates)
std_fgraph.features = [gof.toolbox.PreserveNames]
class AliasedMemoryError(Exception):
    """Raised when two buffers share memory that must not be shared."""
......@@ -1337,8 +1341,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
t1 = time.time()
mode = mode_module.get_mode(mode)
check_for_bad_grad(outputs)
inputs = map(convert_function_input, inputs)
if outputs is not None:
if isinstance(outputs, (list, tuple)):
......
......@@ -70,6 +70,7 @@ from optdb import \
EquilibriumDB, SequenceDB, ProxyDB
from toolbox import \
Feature, \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\
PrintListener, ReplacementDidntRemovedError
......
......@@ -19,6 +19,7 @@ class InconsistencyError(Exception):
"""
pass
class MissingInputError(Exception):
"""
A symbolic input needed to compute the outputs is missing.
......@@ -26,7 +27,6 @@ class MissingInputError(Exception):
pass
class FunctionGraph(utils.object2):
""" WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and a
......@@ -46,46 +46,8 @@ class FunctionGraph(utils.object2):
The .clients field combined with the .owner field and the Apply nodes'
.inputs field allows the graph to be traversed in both directions.
It can also be "extended" using function_graph.extend(some_object). See the
toolbox and ext modules for common extensions.
Features added with the`extend` function can handle the following events:
- feature.on_attach(function_graph)
Called by extend. The feature has great freedom in what
it can do with the function_graph: it may, for example, add methods
to it dynamically.
- feature.on_detach(function_graph)
Called by remove_feature(feature). Should remove any dynamically-added
functionality that it installed into the function_graph.
- feature.on_import(function_graph, node)*
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
- feature.on_prune(function_graph, node)*
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
- feature.on_change_input(function_graph, node, i, r, new_r, [reason=None])*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(function_graph)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
- feature.on_setup_node(function_graph, node):
WRITEME
- feature.on_setup_variable(function_graph, variable):
WRITEME
It can also be "extended" using function_graph.extend(some_object).
See toolbox.Feature for event types and documentation.
Historically, the FunctionGraph was called an Env. Keep this in mind
while reading out-of-date documentation, e-mail support threads, etc.
......
......@@ -19,7 +19,61 @@ class ReplacementDidntRemovedError(Exception):
pass
class Bookkeeper:
class Feature(object):
    """
    Base class for FunctionGraph extensions.

    A Feature is attached to a FunctionGraph via ``extend`` and is then
    notified of graph events through the callbacks below.  Every callback
    here is an optional no-op; subclasses override only the events they
    care about.  See the toolbox and ext modules for common extensions.
    """

    def on_attach(self, function_graph):
        """
        Invoked by ``extend``.  The feature may customize the
        function_graph freely -- for example, by installing new methods
        on it dynamically.
        """

    def on_detach(self, function_graph):
        """
        Invoked by ``remove_feature(feature)``.  Must undo any dynamic
        modifications that this feature installed into function_graph.
        """

    def on_import(self, function_graph, node):
        """
        Invoked for each node imported into function_graph, just before
        the node is actually wired into the graph.
        """

    def on_prune(self, function_graph, node):
        """
        Invoked for each node pruned (removed) from function_graph,
        after the node has been disconnected from the graph.
        """

    def on_change_input(self, function_graph, node, i, r, new_r, reason=None):
        """
        Invoked after ``node.inputs[i]`` is replaced: the change from
        ``r`` to ``new_r`` has already taken effect when this runs.

        Raising an exception here may leave the graph in a broken state
        for all intents and purposes.
        """

    def orderings(self, function_graph):
        """
        Invoked by toposort.  Returns a dictionary mapping each node to
        the list of nodes that must be computed before it.

        Raising an exception here may leave the graph in a broken state
        for all intents and purposes.
        """
        return {}
class Bookkeeper(Feature):
def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
......@@ -30,7 +84,7 @@ class Bookkeeper:
self.on_prune(fgraph, node)
class History:
class History(Feature):
    def __init__(self):
        # Per-fgraph record keyed by the attached fgraph (entries are
        # written as self.history[fgraph] = h elsewhere in this class);
        # the exact payload `h` is not visible from here -- see on_attach.
        self.history = {}
......@@ -69,7 +123,7 @@ class History:
self.history[fgraph] = h
class Validator:
class Validator(Feature):
def on_attach(self, fgraph):
for attr in ('validate', 'validate_time'):
......@@ -224,7 +278,7 @@ class NodeFinder(dict, Bookkeeper):
return all
class PrintListener(object):
class PrintListener(Feature):
    def __init__(self, active=True):
        # Toggle for this listener's output; presumably checked by the
        # print callbacks before emitting anything -- confirm in callers.
        self.active = active
......@@ -251,7 +305,9 @@ class PrintListener(object):
node, i, r, new_r)
class PreserveNames:
class PreserveNames(Feature):
    """
    Carry variable names across replacements: when an unnamed variable
    replaces a named one, copy the old name onto the replacement.
    """

    def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
        # Parameter renamed from the original's `mode` to `node` for
        # consistency with Feature.on_change_input; the callback is
        # invoked positionally, so callers are unaffected.
        if r.name is not None and new_r.name is None:
            new_r.name = r.name
......@@ -194,6 +194,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
gmap[r] = g_r
return gmap
class BadGradOp(gof.Op):
"""
An Op representing a gradient that cannot be computed.
......@@ -239,6 +240,7 @@ class BadGradOp(gof.Op):
    def raise_exc(self):
        # Raise this op's stored exception type (self.exc) with its stored
        # message (self.msg); both are presumably set in __init__, which is
        # not visible in this view -- confirm against the full class.
        raise self.exc(self.msg)
class GradNotImplementedOp(BadGradOp):
""" A BadGradOp representing a gradient that hasn't been implemented yet.
"""
......@@ -261,6 +263,7 @@ class GradNotImplementedOp(BadGradOp):
"%s does not implement its gradient with respect to input %d" \
% (str(type(op)), x_pos))
def grad_not_implemented(op, x_pos, x):
"""
Return an un-computable symbolic variable of type `x.type`.
......@@ -274,59 +277,18 @@ def grad_not_implemented(op, x_pos, x):
return GradNotImplementedOp(op, x_pos)(x)
def check_for_bad_grad( variables ):
"""
variables: A gof.Variable or list thereof
Raises an exception if any of the variables represents
an expression involving a BadGradOp
"""
#implemented using a deque rather than recursion because python recursion
#limit is set low by default
#handle the case where var is a theano.compile.io.SymbolicOutput
if hasattr(variables,'variable'):
variables = [ variables.variable ]
if not (isinstance(variables, list) or \
isinstance(variables, gof.Variable)):
raise TypeError("Expected gof.Variable or list thereof, got "+\
str(type(variables)))
if not isinstance(variables,list):
variables = [ variables ]
vars_to_check = deque(variables)
already_checked = set([])
while True:
try:
var = vars_to_check.pop()
except IndexError:
break
if var not in already_checked:
already_checked.update([var])
#handle the case where var is a theano.compile.io.SymbolicOutput
if hasattr(var, 'variable'):
var = var.variable
def raise_if_bad_grad(node):
    """
    Raise the deferred gradient error if `node` applies a BadGradOp.

    node: an Apply node, or None (variables with no owner pass None).
    Does nothing when node is None or its op is not a BadGradOp;
    otherwise delegates to the op's raise_exc().
    """
    if node is not None:
        if isinstance(node.op, BadGradOp):
            # Bug fix: the original called `op.raise_exc()` with `op`
            # undefined in this scope (NameError); the op carrying the
            # deferred error is node.op.
            node.op.raise_exc()
if not isinstance(var, gof.Variable):
raise TypeError("Expected gof.Variable, got "+str(type(var)))
node = var.owner
class BadGradFeature(gof.Feature):
    # FunctionGraph feature that screens every imported node through
    # raise_if_bad_grad, so building a function graph containing a
    # BadGradOp raises the op's deferred exception immediately.
    def on_import(self, fgraph, node):
        raise_if_bad_grad(node)
if node is not None:
op = node.op
if isinstance(op, BadGradOp):
op.raise_exc()
vars_to_check.extendleft(node.inputs)
#end if node is not None
#end if not already_checked
#end while
theano.compile.function_module.std_fgraph.features.append(BadGradFeature)
########################
......@@ -648,7 +610,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
and ret[-1].name is None:
ret[-1].name = '(d%s/d%s)' % (cost.name, p.name)
check_for_bad_grad(ret)
# new_vars is meant to be a list of all variables created
# by this call to grad(), which will be visible to the caller
# after we return.
new_vars = graph.ancestors(ret,
blockers=graph.ancestors(cost) + list(wrt))
map(raise_if_bad_grad, [v.owner for v in new_vars])
return format_as(using_list, using_tuple, ret)
......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment.