Commit a1be5079 authored by Ian Goodfellow

implemented correct handling of unimplemented gradients

Parent 7ebae191
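For orientation, a minimal sketch of the behaviour this commit introduces, mirroring the test added below (the op name `DummyOp` is illustrative, not part of the diff): an op whose `grad` method returns `grad_not_implemented` yields a graph containing an un-computable marker variable, and both `theano.gradient.grad` and `theano.function` now reject such graphs with `NotImplementedError` instead of letting the marker be silently optimized away.

import theano
import theano.tensor as T
from theano import gof

class DummyOp(gof.Op):
    """Illustrative op whose gradient is not implemented for its input."""
    def make_node(self, x):
        return gof.Apply(self, [x], [x.type()])

    def grad(self, inputs, output_grads):
        # Mark the gradient as unimplemented rather than returning
        # something incorrect (e.g. zeros).
        return [theano.gradient.grad_not_implemented(self, 0, inputs[0])]

x = T.scalar()
y = DummyOp()(x)

# Raises NotImplementedError: the graph returned by grad would contain
# a GradNotImplementedOp, which check_for_bad_grad rejects.
g = theano.gradient.grad(y, x)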
......@@ -13,7 +13,6 @@ from profiling import ProfileStats
from pfunc import pfunc
from numpy import any # to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=None, givens=None,
             no_default_updates=False, accept_inplace=False, name=None,
             rebuild_strict=True, allow_input_downcast=None, profile=None,
......@@ -192,6 +191,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
                           mode=mode,
                           accept_inplace=accept_inplace, name=name)
    else:
        # note: pfunc will also call orig_function -- orig_function is a choke point
        # that all compilation must pass through
        fn = pfunc(params=inputs,
                   outputs=outputs,
                   mode=mode,
......
......@@ -15,6 +15,7 @@ import numpy
import theano
from theano import gof
from theano.gof.python25 import partial
from theano.gradient import check_for_bad_grad
import mode as mode_module
from io import In, SymbolicInput, SymbolicInputKit, SymbolicOutput
......@@ -1336,6 +1337,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
    t1 = time.time()
    mode = mode_module.get_mode(mode)
    check_for_bad_grad(outputs)

    inputs = map(convert_function_input, inputs)
    if outputs is not None:
        if isinstance(outputs, (list, tuple)):
......
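The call added above rejects graphs containing bad-gradient markers before any compilation work is done. A hedged sketch of the check's contract, separate from the diff (variable names are illustrative):

import theano.tensor as T
from theano.gradient import check_for_bad_grad, grad_not_implemented

x = T.scalar()
check_for_bad_grad(x * 2)          # clean graph: returns silently

bad = grad_not_implemented(T.add, 0, x)
check_for_bad_grad([x, x + bad])   # raises NotImplementedError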
"""Driver for gradient calculations."""
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron"
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
__copyright__ = "(c) 2011, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
......@@ -11,9 +11,9 @@ import __builtin__
import logging
import warnings
_logger = logging.getLogger('theano.gradient')
import sys
import numpy # for numeric_grad
from collections import deque
import theano
from theano.raise_op import Raise
......@@ -194,20 +194,131 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
        gmap[r] = g_r
    return gmap


class BadGradOp(gof.Op):
    """
    An Op representing a gradient that cannot be computed.

    theano.tensor.grad checks the graphs it returns to ensure
    they do not contain these ops.
    theano.function also checks that the subgraph it implements
    does not contain these ops.
    """

    def __init__(self, exc, msg=""):
        """
        exc: the exception type to raise if a subgraph contains
            this op.
        msg: the message to include in the exception.
        """
        self.exc = exc
        self.msg = msg

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash((type(self)))

    def __str__(self):
        return "BadGrad{%s,%s}" % (self.exc, self.msg)

    def make_node(self, x):
        return gof.Apply(self, [x], [x.type()])

    def perform(self, node, inputs, out_storage):
        """ This should never be called """
        raise AssertionError("A BadGradOp should never be compiled, " +
                "and certainly not executed.")
        # Note: essentially, this op should just be NaNs_like(inputs[0])
        # but 0 * BadGradOp(x) + y optimizes to just y
        # so until we develop a way of symbolically representing a variable
        # that is always NaN and implement the logic for 0 * NaN = NaN, etc.
        # the only way we can guarantee correctness of a theano function
        # is to guarantee that its initial subgraph contained no BadGradOps

    def raise_exc(self):
        raise self.exc(self.msg)


class GradNotImplementedOp(BadGradOp):
    """ A BadGradOp representing a gradient that hasn't been implemented yet.
    """

    def __init__(self, op, x_pos):
        """
        op: A theano op whose grad is not implemented for some input
        x_pos: An int, giving the index in the op's input list of
            a variable for which the gradient is not implemented
            (if op has unimplemented gradients for several inputs,
            it must still return a separate GradNotImplementedOp for
            each)
        """
        assert isinstance(op, gof.Op)
        assert isinstance(x_pos, int)
        assert x_pos >= 0
        super(GradNotImplementedOp, self).__init__(NotImplementedError,
                "%s does not implement its gradient with respect to input %d"
                % (str(type(op)), x_pos))


def unimplemented_grad(op, x_pos, x):
    """
    DO NOT USE. Remove this function after all usage of it has been
    removed from theano.
    """
    msg = '%s.grad not implemented for input %i' % (op, x_pos)
    return Raise(msg=msg)(x)


def grad_not_implemented(op, x_pos, x):
    """
    Return an un-computable symbolic variable of type `x.type`.

    If any function tries to compute this un-computable variable, an exception
    (NotImplementedError) will be raised indicating that the gradient on the
    `x_pos`'th input of `op` has not been implemented.

    If any call to tensor.grad results in an expression containing this
    un-computable variable, an exception (NotImplementedError) will be
    raised indicating that the gradient on the
    `x_pos`'th input of `op` has not been implemented. Likewise if
    any call to theano.function involves this variable.
    """
    return GradNotImplementedOp(op, x_pos)(x)


def check_for_bad_grad(variables):
    """
    variables: A gof.Variable or list thereof

    Raises an exception if any of the variables represents
    an expression involving a BadGradOp
    """
    # implemented using a deque rather than recursion because python recursion
    # limit is set low by default
    if not (isinstance(variables, list) or
            isinstance(variables, gof.Variable)):
        raise TypeError("Expected gof.Variable or list thereof, got " +
                str(type(variables)))
    if not isinstance(variables, list):
        variables = [variables]
    vars_to_check = deque(variables)
    already_checked = set([])
    while True:
        try:
            var = vars_to_check.pop()
        except IndexError:
            break
        if var not in already_checked:
            already_checked.update([var])
            assert isinstance(var, gof.Variable)
            node = var.owner
            if node is not None:
                op = node.op
                if isinstance(op, BadGradOp):
                    op.raise_exc()
                vars_to_check.extendleft(node.inputs)
            # end if node is not None
        # end if not already_checked
    # end while
########################
# R Operator
......@@ -528,6 +639,8 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
                and ret[-1].name is None:
            ret[-1].name = '(d%s/d%s)' % (cost.name, p.name)

    check_for_bad_grad(ret)

    return format_as(using_list, using_tuple, ret)
......
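The notes in BadGradOp.perform and in the test below explain why the check runs on the user-supplied graph rather than at execution time. A hedged sketch of the failure mode being guarded against (illustrative only; which rewrites actually fire depends on the compilation mode):

import theano
import theano.tensor as T
from theano.gradient import grad_not_implemented

x = T.vector()
bad = grad_not_implemented(T.add, 0, x)
y = 0 * bad + x

# If the bad-gradient marker were only detected at execution time, the
# optimizer could first rewrite `0 * bad + x` to just `x`, dropping the
# marker and silently producing a function that ignores the missing
# gradient. check_for_bad_grad therefore runs in orig_function on the
# unoptimized graph, so this call fails up front:
f = theano.function([x], y)    # raises NotImplementedError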
......@@ -251,13 +251,36 @@ class test_grad_sources_inputs(unittest.TestCase):
        self.assertTrue(g[a1.inputs[0]] == 6)
        self.assertTrue(g[a1.inputs[1]] == 11)


def test_unimplemented_grad():
def test_unimplemented_grad_func():
    # tests that function compilation catches unimplemented grads in the graph
    a = theano.tensor.vector()
    b = theano.gradient.unimplemented_grad(theano.tensor.add, 1, a)
    f = theano.function([a], b)
    b = theano.gradient.grad_not_implemented(theano.tensor.add, 0, a)
    try:
        f([1, 2, 3])
        f = theano.function([a], b)
        assert 0
        # Note: it's important that the GradNotImplementedOp is caught
        # at COMPILATION time, not execution time.
        # If the uncomputable variable is, for example, multiplied by 0,
        # it could be optimized out of the final graph.
    except NotImplementedError:
        pass


def test_unimplemented_grad_grad():
    # tests that unimplemented grads are caught in the grad method
    class DummyOp(gof.Op):
        def make_node(self, x):
            return gof.Apply(self, [x], [x.type()])

        def grad(self, inputs, output_grads):
            return [theano.gradient.grad_not_implemented(self, 0, inputs[0])]

    a = theano.tensor.scalar()
    b = DummyOp()(a)
    try:
        g = theano.gradient.grad(b, a)
        assert False
    except NotImplementedError:
        pass
......