提交 487d25ed authored 作者: nouiz's avatar nouiz

Merge pull request #878 from goodfeli/unimp_undef_grad

Adds correct handling of unimplemented gradients
...@@ -24,7 +24,8 @@ from theano.compile.function_module import (FunctionMaker, ...@@ -24,7 +24,8 @@ from theano.compile.function_module import (FunctionMaker,
infer_reuse_pattern, infer_reuse_pattern,
SymbolicInputKit, SymbolicInputKit,
SymbolicOutput, SymbolicOutput,
Supervisor) Supervisor,
std_fgraph)
from theano.compile.mode import Mode, register_mode from theano.compile.mode import Mode, register_mode
AddConfigVar('DebugMode.patience', AddConfigVar('DebugMode.patience',
...@@ -684,8 +685,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -684,8 +685,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
if not (spec.mutable or (hasattr(fgraph, 'destroyers') if not (spec.mutable or (hasattr(fgraph, 'destroyers')
and fgraph.destroyers(input))))) and fgraph.destroyers(input)))))
# If named nodes are replaced, keep the name for feature in std_fgraph.features:
fgraph.extend(gof.toolbox.PreserveNames()) fgraph.extend(feature)
return fgraph, map(SymbolicOutput, updates), equivalence_tracker return fgraph, map(SymbolicOutput, updates), equivalence_tracker
......
...@@ -13,7 +13,6 @@ from profiling import ProfileStats ...@@ -13,7 +13,6 @@ from profiling import ProfileStats
from pfunc import pfunc from pfunc import pfunc
from numpy import any # to work in python 2.4 from numpy import any # to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=None, givens=None, def function(inputs, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, profile=None, rebuild_strict=True, allow_input_downcast=None, profile=None,
...@@ -192,6 +191,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -192,6 +191,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
mode=mode, mode=mode,
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
else: else:
#note: pfunc will also call orig_function-- orig_function is a choke point
# that all compilation must pass through
fn = pfunc(params=inputs, fn = pfunc(params=inputs,
outputs=outputs, outputs=outputs,
mode=mode, mode=mode,
......
...@@ -143,9 +143,31 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -143,9 +143,31 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
fgraph.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and fgraph.destroyers(input))))) fgraph.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and fgraph.destroyers(input)))))
# If named nodes are replaced, keep the name # If named nodes are replaced, keep the name
fgraph.extend(gof.toolbox.PreserveNames()) for feature in std_fgraph.features:
fgraph.extend(feature())
return fgraph, map(SymbolicOutput, updates) return fgraph, map(SymbolicOutput, updates)
std_fgraph.features = [gof.toolbox.PreserveNames()]
class UncomputableFeature(gof.Feature):
"""A feature that ensures the graph never contains any
uncomputable nodes. This check must be made at compile time
rather than runtime in order to make sure that NaN nodes are
not optimized out. It must be done as a Feature so that
the fgraph will continually check that optimizations have
not introduce any uncomputable nodes."""
def on_attach(self, fgraph):
for node in fgraph.nodes:
return self.on_import(fgraph, node)
def on_import(self, fgraph, node):
gof.op.raise_if_uncomputable(node)
std_fgraph.features.append(UncomputableFeature())
class AliasedMemoryError(Exception): class AliasedMemoryError(Exception):
"""Memory is aliased that should not be""" """Memory is aliased that should not be"""
pass pass
......
...@@ -70,6 +70,7 @@ from optdb import \ ...@@ -70,6 +70,7 @@ from optdb import \
EquilibriumDB, SequenceDB, ProxyDB EquilibriumDB, SequenceDB, ProxyDB
from toolbox import \ from toolbox import \
Feature, \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\ Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\
PrintListener, ReplacementDidntRemovedError PrintListener, ReplacementDidntRemovedError
......
...@@ -19,6 +19,7 @@ class InconsistencyError(Exception): ...@@ -19,6 +19,7 @@ class InconsistencyError(Exception):
""" """
pass pass
class MissingInputError(Exception): class MissingInputError(Exception):
""" """
A symbolic input needed to compute the outputs is missing. A symbolic input needed to compute the outputs is missing.
...@@ -26,7 +27,6 @@ class MissingInputError(Exception): ...@@ -26,7 +27,6 @@ class MissingInputError(Exception):
pass pass
class FunctionGraph(utils.object2): class FunctionGraph(utils.object2):
""" WRITEME """ WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and a A FunctionGraph represents a subgraph bound by a set of input variables and a
...@@ -46,46 +46,8 @@ class FunctionGraph(utils.object2): ...@@ -46,46 +46,8 @@ class FunctionGraph(utils.object2):
The .clients field combined with the .owner field and the Apply nodes' The .clients field combined with the .owner field and the Apply nodes'
.inputs field allows the graph to be traversed in both directions. .inputs field allows the graph to be traversed in both directions.
It can also be "extended" using function_graph.extend(some_object). See the It can also be "extended" using function_graph.extend(some_object).
toolbox and ext modules for common extensions. See toolbox.Feature for event types and documentation.
Features added with the`extend` function can handle the following events:
- feature.on_attach(function_graph)
Called by extend. The feature has great freedom in what
it can do with the function_graph: it may, for example, add methods
to it dynamically.
- feature.on_detach(function_graph)
Called by remove_feature(feature). Should remove any dynamically-added
functionality that it installed into the function_graph.
- feature.on_import(function_graph, node)*
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
- feature.on_prune(function_graph, node)*
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
- feature.on_change_input(function_graph, node, i, r, new_r, [reason=None])*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(function_graph)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
- feature.on_setup_node(function_graph, node):
WRITEME
- feature.on_setup_variable(function_graph, variable):
WRITEME
Historically, the FunctionGraph was called an Env. Keep this in mind Historically, the FunctionGraph was called an Env. Keep this in mind
while reading out-of-date documentation, e-mail support threads, etc. while reading out-of-date documentation, e-mail support threads, etc.
...@@ -472,13 +434,19 @@ class FunctionGraph(utils.object2): ...@@ -472,13 +434,19 @@ class FunctionGraph(utils.object2):
# takes a sequence, and since this is a kind of container you # takes a sequence, and since this is a kind of container you
# would expect it to do similarly. # would expect it to do similarly.
def extend(self, feature): def extend(self, feature):
"""WRITEME """
Adds a feature to this function_graph. The feature may define one Adds a gof.toolbox.Feature to this function_graph
or more of the following methods: and triggers its on_attach callback
""" """
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
#it would be nice if we could require a specific class instead of
#a "workalike" so we could do actual error checking
#if not isinstance(feature, toolbox.Feature):
# raise TypeError("Expected gof.toolbox.Feature instance, got "+\
# str(type(feature)))
attach = getattr(feature, 'on_attach', None) attach = getattr(feature, 'on_attach', None)
if attach is not None: if attach is not None:
try: try:
......
...@@ -606,6 +606,56 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -606,6 +606,56 @@ class Op(utils.object2, PureOp, CLinkerOp):
rval.lazy = False rval.lazy = False
return rval return rval
class UncomputableOp(Op):
"""
An Op representing an expression that cannot be computed.
theano.function checks that the subgraph it implements
does not contain these ops, and that optimization does not
introduce any such ops.
theano.tensor.grad checks the graphs it returns to ensure
they do not contain these ops.
"""
def __init__(self, exc, msg=""):
"""
exc: the exception type to raise if a subgraph contains
this op.
msg: the message to include in the exception.
"""
self.exc = exc
self.msg = msg
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash((type(self)))
def __str__(self):
return "Uncomputable{%s,%s}"%(self.exc,self.msg)
def make_node(self,x):
return graph.Apply(self, [x], [x.type()] )
def perform(self, node, inputs, out_storage):
""" This should never be called"""
raise AssertionError("A BadGradOp should never be compiled, "+\
"and certainly not executed.")
#Note: essentially, this op should just be NaNs_like(inputs[0])
#but 0 * BadGradOp(x) + y optimizes to just y
#so until we develop a way of symbolically representing a variable
#that is always NaN and implement the logic for 0 * NaN = NaN, etc.
#the only way we can guarantee correctness of a theano function
#is to guarantee that its initial subgraph contained no BadGradOps
def raise_exc(self):
raise self.exc(self.msg)
def raise_if_uncomputable(node):
if node is not None:
if isinstance(node.op, UncomputableOp):
node.op.raise_exc()
def get_test_value(v): def get_test_value(v):
""" """
......
...@@ -19,7 +19,72 @@ class ReplacementDidntRemovedError(Exception): ...@@ -19,7 +19,72 @@ class ReplacementDidntRemovedError(Exception):
pass pass
class Bookkeeper: class Feature(object):
"""
Base class for FunctionGraph extensions.
A Feature is an object with several callbacks that are triggered
by various operations on FunctionGraphs. It can be used to enforce
graph properties at all stages of graph optimization.
See toolbox and ext modules for common extensions.
"""
def on_attach(self, function_graph):
"""
Called by FunctionGraph.extend, the method that attaches the feature
to the FunctionGraph. Since this is called after the FunctionGraph
is initially populated, this is where you should run checks on the
initial contents of the FunctionGraph.
The feature has great freedom in what
it can do with the function_graph: it may, for example, add methods
to it dynamically.
"""
def on_detach(self, function_graph):
"""
Called by remove_feature(feature). Should remove any dynamically-added
functionality that it installed into the function_graph.
"""
def on_import(self, function_graph, node):
"""
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
Note: on_import is not called when the graph is created. If you
want to detect the first nodes to be implemented to the graph,
you should do this by implementing on_attach.
"""
def on_prune(self, function_graph, node):
"""
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
"""
def on_change_input(self, function_graph, node, i, r, new_r, reason=None):
"""
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
If you raise an exception in this function, the state of the graph
might be broken for all intents and purposes.
"""
def orderings(self, function_graph):
"""
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
If you raise an exception in this function, the state of the graph
might be broken for all intents and purposes.
"""
return {}
class Bookkeeper(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
...@@ -30,7 +95,7 @@ class Bookkeeper: ...@@ -30,7 +95,7 @@ class Bookkeeper:
self.on_prune(fgraph, node) self.on_prune(fgraph, node)
class History: class History(Feature):
def __init__(self): def __init__(self):
self.history = {} self.history = {}
...@@ -69,7 +134,7 @@ class History: ...@@ -69,7 +134,7 @@ class History:
self.history[fgraph] = h self.history[fgraph] = h
class Validator: class Validator(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for attr in ('validate', 'validate_time'): for attr in ('validate', 'validate_time'):
...@@ -224,7 +289,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -224,7 +289,7 @@ class NodeFinder(dict, Bookkeeper):
return all return all
class PrintListener(object): class PrintListener(Feature):
def __init__(self, active=True): def __init__(self, active=True):
self.active = active self.active = active
...@@ -251,7 +316,10 @@ class PrintListener(object): ...@@ -251,7 +316,10 @@ class PrintListener(object):
node, i, r, new_r) node, i, r, new_r)
class PreserveNames: class PreserveNames(Feature):
def on_change_input(self, fgraph, mode, i, r, new_r, reason=None): def on_change_input(self, fgraph, mode, i, r, new_r, reason=None):
if r.name is not None and new_r.name is None: if r.name is not None and new_r.name is None:
new_r.name = r.name new_r.name = r.name
"""Driver for gradient calculations.""" """Driver for gradient calculations."""
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron" __authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
__copyright__ = "(c) 2011, Universite de Montreal" __copyright__ = "(c) 2011, Universite de Montreal"
__license__ = "3-clause BSD License" __license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>" __contact__ = "theano-dev <theano-dev@googlegroups.com>"
...@@ -11,9 +11,9 @@ import __builtin__ ...@@ -11,9 +11,9 @@ import __builtin__
import logging import logging
import warnings import warnings
_logger = logging.getLogger('theano.gradient') _logger = logging.getLogger('theano.gradient')
import sys
import numpy # for numeric_grad import numpy # for numeric_grad
from collections import deque
import theano import theano
from theano.raise_op import Raise from theano.raise_op import Raise
...@@ -195,19 +195,42 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -195,19 +195,42 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
return gmap return gmap
def unimplemented_grad(op, x_pos, x): class GradNotImplementedOp(gof.op.UncomputableOp):
""" A BadGradOp representing a gradient that hasn't been implemented yet.
""" """
DO NOT USE. Remove this function after all usage of it has been
removed from theano.
def __init__(self, op, x_pos):
"""
op: A theano op whose grad is not implemented for some input
x_pos: An int, giving the index in the op's input list of
a variable for which the gradient is not implemented
(if op has unimplemented gradients for several inputs,
it must still return a separate UnimplementedGradOp for
each)
"""
assert isinstance(op, gof.Op)
assert isinstance(x_pos, int)
assert x_pos >= 0
super(GradNotImplementedOp,self).__init__(NotImplementedError,
"%s does not implement its gradient with respect to input %d" \
% (str(type(op)), x_pos))
def grad_not_implemented(op, x_pos, x):
"""
Return an un-computable symbolic variable of type `x.type`. Return an un-computable symbolic variable of type `x.type`.
If any function tries to compute this un-computable variable, an exception If any call to tensor.grad results in an expression containing this
(NotImplementedError) will be raised indicating that the gradient on the un-computable variable, an exception (NotImplementedError) will be
`x_pos`'th input of `op` has not been implemented. raised indicating that the gradient on the
`x_pos`'th input of `op` has not been implemented. Likewise if
any call to theano.function involves this variable.
""" """
msg = '%s.grad not implemented for input %i' % (op, x_pos)
return Raise(msg=msg)(x) return GradNotImplementedOp(op, x_pos)(x)
######################## ########################
# R Operator # R Operator
...@@ -528,6 +551,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -528,6 +551,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
and ret[-1].name is None: and ret[-1].name is None:
ret[-1].name = '(d%s/d%s)' % (cost.name, p.name) ret[-1].name = '(d%s/d%s)' % (cost.name, p.name)
# new_vars is meant to be a list of all variables created
# by this call to grad(), which will be visible to the caller
# after we return.
new_vars = gof.graph.ancestors(ret,
blockers=gof.graph.ancestors([cost]) + list(wrt))
map(gof.op.raise_if_uncomputable, [v.owner for v in new_vars])
return format_as(using_list, using_tuple, ret) return format_as(using_list, using_tuple, ret)
......
...@@ -3,6 +3,7 @@ from theano import Op, Apply ...@@ -3,6 +3,7 @@ from theano import Op, Apply
import theano.tensor as T import theano.tensor as T
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.sandbox.cuda import cuda_available, GpuOp from theano.sandbox.cuda import cuda_available, GpuOp
from theano.gradient import grad_not_implemented
if cuda_available: if cuda_available:
from theano.sandbox.cuda import CudaNdarrayType from theano.sandbox.cuda import CudaNdarrayType
...@@ -10,14 +11,6 @@ if cuda_available: ...@@ -10,14 +11,6 @@ if cuda_available:
from theano.sandbox.cuda.opt import register_opt as register_gpu_opt from theano.sandbox.cuda.opt import register_opt as register_gpu_opt
class BadOldCode(Exception):
"""
We create a specific Exception to be sure it does not get caught by
mistake.
"""
pass
class Images2Neibs(Op): class Images2Neibs(Op):
def __init__(self, mode='valid'): def __init__(self, mode='valid'):
""" """
...@@ -91,15 +84,8 @@ class Images2Neibs(Op): ...@@ -91,15 +84,8 @@ class Images2Neibs(Op):
def grad(self, inp, grads): def grad(self, inp, grads):
x, neib_shape, neib_step = inp x, neib_shape, neib_step = inp
gz, = grads gz, = grads
if self.mode in ['valid', 'ignore_borders']: return [ grad_not_implemented(self, i, ip) \
raise BadOldCode("The Images2Neibs grad is not implemented." for i, ip in enumerate(inp) ]
" It was in the past, but returned the wrong"
" answer!")
# This is the reverse of the op, not the grad!
return [neibs2images(gz, neib_shape, x.shape, mode=self.mode),
None, None]
else:
raise NotImplementedError()
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (5,)
......
...@@ -251,13 +251,36 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -251,13 +251,36 @@ class test_grad_sources_inputs(unittest.TestCase):
self.assertTrue(g[a1.inputs[0]] == 6) self.assertTrue(g[a1.inputs[0]] == 6)
self.assertTrue(g[a1.inputs[1]] == 11) self.assertTrue(g[a1.inputs[1]] == 11)
def test_unimplemented_grad(): def test_unimplemented_grad_func():
#tests that function compilation catches unimplemented grads in the graph
a = theano.tensor.vector() a = theano.tensor.vector()
b = theano.gradient.unimplemented_grad(theano.tensor.add, 1, a) b = theano.gradient.grad_not_implemented(theano.tensor.add, 0, a)
f = theano.function([a], b)
try: try:
f([1,2,3]) f = theano.function([a], b)
assert 0 assert 0
#Note: it's important that the NotImplementedGradOp is caught
#at COMPILATION time, not execution time.
#If the uncomputable variable is, for example, multiplied by 0,
#it could be optimized out of the final graph.
except NotImplementedError:
pass
def test_unimplemented_grad_grad():
#tests that unimplemented grads are caught in the grad method
class DummyOp(gof.Op):
def make_node(self, x):
return gof.Apply(self, [x], [x.type()])
def grad(self, inputs, output_grads):
return [ theano.gradient.grad_not_implemented(self, 0, inputs[0]) ]
a = theano.tensor.scalar()
b = DummyOp()(a)
try:
g = theano.gradient.grad(b,a)
assert False
except NotImplementedError: except NotImplementedError:
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论