Commit 487d25ed, authored by nouiz

Merge pull request #878 from goodfeli/unimp_undef_grad

Adds correct handling of unimplemented gradients
......@@ -24,7 +24,8 @@ from theano.compile.function_module import (FunctionMaker,
infer_reuse_pattern,
SymbolicInputKit,
SymbolicOutput,
Supervisor)
Supervisor,
std_fgraph)
from theano.compile.mode import Mode, register_mode
AddConfigVar('DebugMode.patience',
......@@ -684,8 +685,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
if not (spec.mutable or (hasattr(fgraph, 'destroyers')
and fgraph.destroyers(input)))))
# If named nodes are replaced, keep the name
fgraph.extend(gof.toolbox.PreserveNames())
for feature in std_fgraph.features:
fgraph.extend(feature)
return fgraph, map(SymbolicOutput, updates), equivalence_tracker
......
......@@ -13,7 +13,6 @@ from profiling import ProfileStats
from pfunc import pfunc
from numpy import any # to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, profile=None,
......@@ -192,6 +191,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
mode=mode,
accept_inplace=accept_inplace, name=name)
else:
#note: pfunc will also call orig_function-- orig_function is a choke point
# that all compilation must pass through
fn = pfunc(params=inputs,
outputs=outputs,
mode=mode,
......
......@@ -143,9 +143,31 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
fgraph.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and fgraph.destroyers(input)))))
# If named nodes are replaced, keep the name
fgraph.extend(gof.toolbox.PreserveNames())
for feature in std_fgraph.features:
fgraph.extend(feature())
return fgraph, map(SymbolicOutput, updates)
std_fgraph.features = [gof.toolbox.PreserveNames()]
class UncomputableFeature(gof.Feature):
    """A feature that ensures the graph never contains any
    uncomputable nodes.

    This check must be made at compile time rather than runtime in
    order to make sure that NaN nodes are not optimized out. It must
    be done as a Feature so that the fgraph will continually check
    that optimizations have not introduced any uncomputable nodes.
    """

    def on_attach(self, fgraph):
        # Validate every node already present in the graph.
        # Bug fix: the original body did `return self.on_import(...)`
        # inside the loop, which checked only the first node.
        for node in fgraph.nodes:
            self.on_import(fgraph, node)

    def on_import(self, fgraph, node):
        # Raises the node's stored exception (e.g. NotImplementedError)
        # if `node` is an apply of an UncomputableOp.
        gof.op.raise_if_uncomputable(node)

std_fgraph.features.append(UncomputableFeature())
class AliasedMemoryError(Exception):
    """Raised when memory is aliased that should not be."""
......
......@@ -70,6 +70,7 @@ from optdb import \
EquilibriumDB, SequenceDB, ProxyDB
from toolbox import \
Feature, \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\
PrintListener, ReplacementDidntRemovedError
......
......@@ -19,6 +19,7 @@ class InconsistencyError(Exception):
"""
pass
class MissingInputError(Exception):
"""
A symbolic input needed to compute the outputs is missing.
......@@ -26,7 +27,6 @@ class MissingInputError(Exception):
pass
class FunctionGraph(utils.object2):
""" WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and a
......@@ -46,46 +46,8 @@ class FunctionGraph(utils.object2):
The .clients field combined with the .owner field and the Apply nodes'
.inputs field allows the graph to be traversed in both directions.
It can also be "extended" using function_graph.extend(some_object). See the
toolbox and ext modules for common extensions.
Features added with the`extend` function can handle the following events:
- feature.on_attach(function_graph)
Called by extend. The feature has great freedom in what
it can do with the function_graph: it may, for example, add methods
to it dynamically.
- feature.on_detach(function_graph)
Called by remove_feature(feature). Should remove any dynamically-added
functionality that it installed into the function_graph.
- feature.on_import(function_graph, node)*
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
- feature.on_prune(function_graph, node)*
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
- feature.on_change_input(function_graph, node, i, r, new_r, [reason=None])*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(function_graph)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
- feature.on_setup_node(function_graph, node):
WRITEME
- feature.on_setup_variable(function_graph, variable):
WRITEME
It can also be "extended" using function_graph.extend(some_object).
See toolbox.Feature for event types and documentation.
Historically, the FunctionGraph was called an Env. Keep this in mind
while reading out-of-date documentation, e-mail support threads, etc.
......@@ -472,13 +434,19 @@ class FunctionGraph(utils.object2):
# takes a sequence, and since this is a kind of container you
# would expect it to do similarly.
def extend(self, feature):
"""WRITEME
Adds a feature to this function_graph. The feature may define one
or more of the following methods:
"""
Adds a gof.toolbox.Feature to this function_graph
and triggers its on_attach callback
"""
if feature in self._features:
return # the feature is already present
#it would be nice if we could require a specific class instead of
#a "workalike" so we could do actual error checking
#if not isinstance(feature, toolbox.Feature):
# raise TypeError("Expected gof.toolbox.Feature instance, got "+\
# str(type(feature)))
attach = getattr(feature, 'on_attach', None)
if attach is not None:
try:
......
......@@ -606,6 +606,56 @@ class Op(utils.object2, PureOp, CLinkerOp):
rval.lazy = False
return rval
class UncomputableOp(Op):
    """
    An Op representing an expression that cannot be computed.

    theano.function checks that the subgraph it implements does not
    contain these ops, and that optimization does not introduce any
    such ops. theano.tensor.grad checks the graphs it returns to
    ensure they do not contain these ops.
    """

    def __init__(self, exc, msg=""):
        """
        exc: the exception type to raise if a subgraph contains
            this op.
        msg: the message to include in the exception.
        """
        self.exc = exc
        self.msg = msg

    def __eq__(self, other):
        # All instances of the same subclass compare equal so that
        # graph merging works; kept consistent with __hash__.
        return type(self) == type(other)

    def __hash__(self):
        # Fixed: `hash((type(self)))` had redundant parens; the
        # argument was never actually a tuple.
        return hash(type(self))

    def __str__(self):
        return "Uncomputable{%s,%s}" % (self.exc, self.msg)

    def make_node(self, x):
        # The (never-computed) output has the same type as the input.
        return graph.Apply(self, [x], [x.type()])

    def perform(self, node, inputs, out_storage):
        """ This should never be called"""
        # Fixed stale name: the message referred to "BadGradOp", an
        # earlier name of this class.
        raise AssertionError("An UncomputableOp should never be compiled, "
                             "and certainly not executed.")

    # Note: essentially, this op should just be NaNs_like(inputs[0]),
    # but 0 * UncomputableOp(x) + y optimizes to just y, so until we
    # develop a way of symbolically representing a variable that is
    # always NaN (and implement the logic for 0 * NaN = NaN, etc.),
    # the only way we can guarantee correctness of a theano function
    # is to guarantee that its initial subgraph contained no
    # UncomputableOps.

    def raise_exc(self):
        # Raise the configured exception type with the stored message.
        raise self.exc(self.msg)
def raise_if_uncomputable(node):
    """
    If `node` is an apply of an UncomputableOp, raise the exception
    that op was constructed with. A None node (e.g. the owner of a
    graph input) is silently accepted.
    """
    if node is None:
        return
    if isinstance(node.op, UncomputableOp):
        node.op.raise_exc()
def get_test_value(v):
"""
......
......@@ -19,7 +19,72 @@ class ReplacementDidntRemovedError(Exception):
pass
class Bookkeeper:
class Feature(object):
    """
    Base class for FunctionGraph extensions.

    A Feature is a collection of callbacks triggered by operations on
    a FunctionGraph, and can be used to enforce graph properties at
    every stage of graph optimization. See the toolbox and ext
    modules for common extensions.
    """

    def on_attach(self, function_graph):
        """
        Invoked by FunctionGraph.extend, the method that attaches the
        feature to the FunctionGraph. Since the FunctionGraph is
        already populated when this runs, this is where checks on its
        initial contents belong.

        The feature may do anything it likes with the function_graph,
        including dynamically adding methods to it.
        """

    def on_detach(self, function_graph):
        """
        Invoked by remove_feature(feature). Should undo any
        dynamically-added functionality that was installed into the
        function_graph.
        """

    def on_import(self, function_graph, node):
        """
        Invoked whenever a node is imported into function_graph, just
        before the node is actually connected to the graph.

        Note: not called for the nodes the graph is created with; to
        see those first nodes, implement on_attach.
        """

    def on_prune(self, function_graph, node):
        """
        Invoked whenever a node is pruned (removed) from the
        function_graph, after it has been disconnected from the graph.
        """

    def on_change_input(self, function_graph, node, i, r, new_r, reason=None):
        """
        Invoked whenever node.inputs[i] is changed from r to new_r; by
        the time the callback runs, the change has already happened.

        Raising an exception here may leave the graph broken for all
        intents and purposes.
        """

    def orderings(self, function_graph):
        """
        Invoked by toposort. Returns a dictionary mapping each node to
        the list of nodes that should be computed before it.

        Raising an exception here may leave the graph broken for all
        intents and purposes.
        """
        return {}
class Bookkeeper(Feature):
def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
......@@ -30,7 +95,7 @@ class Bookkeeper:
self.on_prune(fgraph, node)
class History:
class History(Feature):
def __init__(self):
self.history = {}
......@@ -69,7 +134,7 @@ class History:
self.history[fgraph] = h
class Validator:
class Validator(Feature):
def on_attach(self, fgraph):
for attr in ('validate', 'validate_time'):
......@@ -224,7 +289,7 @@ class NodeFinder(dict, Bookkeeper):
return all
class PrintListener(object):
class PrintListener(Feature):
def __init__(self, active=True):
self.active = active
......@@ -251,7 +316,10 @@ class PrintListener(object):
node, i, r, new_r)
class PreserveNames:
class PreserveNames(Feature):
    """
    Propagates a replaced variable's name onto its replacement when
    the replacement has no name of its own.
    """

    def on_change_input(self, fgraph, mode, i, r, new_r, reason=None):
        # NOTE(review): `mode` is presumably the Apply node whose input
        # changed -- confirm against FunctionGraph's callback signature.
        if new_r.name is None and r.name is not None:
            new_r.name = r.name
"""Driver for gradient calculations."""
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron"
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
__copyright__ = "(c) 2011, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
......@@ -11,9 +11,9 @@ import __builtin__
import logging
import warnings
_logger = logging.getLogger('theano.gradient')
import sys
import numpy # for numeric_grad
from collections import deque
import theano
from theano.raise_op import Raise
......@@ -195,19 +195,42 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
return gmap
class GradNotImplementedOp(gof.op.UncomputableOp):
    """
    An UncomputableOp representing a gradient that hasn't been
    implemented yet.
    """

    def __init__(self, op, x_pos):
        """
        op: A theano op whose grad is not implemented for some input
        x_pos: An int, giving the index in the op's input list of
            a variable for which the gradient is not implemented
            (if op has unimplemented gradients for several inputs,
            it must still return a separate GradNotImplementedOp for
            each)
        """
        assert isinstance(op, gof.Op)
        assert isinstance(x_pos, int)
        assert x_pos >= 0
        super(GradNotImplementedOp, self).__init__(NotImplementedError,
                "%s does not implement its gradient with respect to input %d"
                % (str(type(op)), x_pos))
def grad_not_implemented(op, x_pos, x):
    """
    Return an un-computable symbolic variable of type `x.type`.

    If any call to tensor.grad results in an expression containing this
    un-computable variable, an exception (NotImplementedError) will be
    raised indicating that the gradient on the `x_pos`'th input of `op`
    has not been implemented. Likewise if any call to theano.function
    involves this variable.
    """
    return GradNotImplementedOp(op, x_pos)(x)
########################
# R Operator
......@@ -528,6 +551,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
and ret[-1].name is None:
ret[-1].name = '(d%s/d%s)' % (cost.name, p.name)
# new_vars is meant to be a list of all variables created
# by this call to grad(), which will be visible to the caller
# after we return.
new_vars = gof.graph.ancestors(ret,
blockers=gof.graph.ancestors([cost]) + list(wrt))
map(gof.op.raise_if_uncomputable, [v.owner for v in new_vars])
return format_as(using_list, using_tuple, ret)
......
......@@ -3,6 +3,7 @@ from theano import Op, Apply
import theano.tensor as T
from theano.gof import local_optimizer
from theano.sandbox.cuda import cuda_available, GpuOp
from theano.gradient import grad_not_implemented
if cuda_available:
from theano.sandbox.cuda import CudaNdarrayType
......@@ -10,14 +11,6 @@ if cuda_available:
from theano.sandbox.cuda.opt import register_opt as register_gpu_opt
class BadOldCode(Exception):
    """
    A dedicated Exception subclass, used so it cannot be caught by
    mistake.
    """
class Images2Neibs(Op):
def __init__(self, mode='valid'):
"""
......@@ -91,15 +84,8 @@ class Images2Neibs(Op):
def grad(self, inp, grads):
x, neib_shape, neib_step = inp
gz, = grads
if self.mode in ['valid', 'ignore_borders']:
raise BadOldCode("The Images2Neibs grad is not implemented."
" It was in the past, but returned the wrong"
" answer!")
# This is the reverse of the op, not the grad!
return [neibs2images(gz, neib_shape, x.shape, mode=self.mode),
None, None]
else:
raise NotImplementedError()
return [ grad_not_implemented(self, i, ip) \
for i, ip in enumerate(inp) ]
    def c_code_cache_version(self):
        # Version tag for the compiled-C-code cache; bumped whenever the
        # generated C implementation changes so stale modules are rebuilt.
        return (5,)
......
......@@ -251,13 +251,36 @@ class test_grad_sources_inputs(unittest.TestCase):
self.assertTrue(g[a1.inputs[0]] == 6)
self.assertTrue(g[a1.inputs[1]] == 11)
def test_unimplemented_grad_func():
    # Tests that function compilation catches unimplemented grads in
    # the graph.
    a = theano.tensor.vector()
    b = theano.gradient.grad_not_implemented(theano.tensor.add, 0, a)
    try:
        f = theano.function([a], b)
        assert 0
        # Note: it's important that the NotImplementedError is raised
        # at COMPILATION time, not execution time.
        # If the uncomputable variable is, for example, multiplied by 0,
        # it could be optimized out of the final graph.
    except NotImplementedError:
        pass
def test_unimplemented_grad_grad():
    # The grad() driver itself must reject graphs that contain
    # unimplemented gradients.
    class DummyOp(gof.Op):
        def make_node(self, x):
            return gof.Apply(self, [x], [x.type()])

        def grad(self, inputs, output_grads):
            return [theano.gradient.grad_not_implemented(self, 0,
                                                         inputs[0])]

    a = theano.tensor.scalar()
    b = DummyOp()(a)
    try:
        theano.gradient.grad(b, a)
    except NotImplementedError:
        return  # expected: the unimplemented gradient was detected
    assert False
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment