提交 f9dd5a84 authored 作者: James Bergstra's avatar James Bergstra

added unimplented_grad method to theano.gradient

上级 3dd4cf97
"""Driver for general gradient calculations.""" """Driver for gradient calculations."""
__authors__ = "James Bergstra"
__copyright__ = "(c) 2011, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import logging
_logger = logging.getLogger('theano.gradient')
import sys import sys
import gof #, gof.variable
import numpy #for numeric_grad import numpy #for numeric_grad
import gof #, gof.variable
from gof.python25 import all from gof.python25 import all
import gof.utils import gof.utils
import logging from raise_op import Raise
_logger = logging.getLogger('theano.gradient')
def warning(*msg): def warning(*msg):
_logger.warning('WARNING theano.gradient: '+' '.join(msg)) _logger.warning('WARNING theano.gradient: '+' '.join(msg))
def info(*msg): def info(*msg):
...@@ -106,4 +114,14 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -106,4 +114,14 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
gmap[r] = g_r gmap[r] = g_r
return gmap return gmap
def unimplemented_grad(op, x_pos, x):
"""
Return an un-computable symbolic variable of type `x.type`.
If any function tries to compute this un-computable variable, an exception
(NotImplementedError) will be raised indicating that the gradient on the
`x_pos`'th input of `op` has not been implemented.
"""
msg = '%s.grad not implemented for input %i'%(op, x_pos)
return Raise(msg=msg)(x)
"""Symbolic Op for raising an exception."""
__authors__ = "James Bergstra"
__copyright__ = "(c) 2011, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
import gof
class Raise(gof.Op):
"""Op whose perform() raises an exception.
"""
def __init__(self, msg="", exc=NotImplementedError):
"""
msg - the argument to the exception
exc - an exception class to raise in self.perform
"""
self.msg = msg
self.exc = exc
def __eq__(self, other):
# Note: the msg does not technically have to be in the hash and eq
# because it doesn't affect the return value.
return (type(self) == type(other)
and self.msg == other.msg
and self.exc == other.exc)
def __hash__(self):
return hash((type(self), self.msg, self.exc))
def __str__(self):
return "Raise{%s(%s)}"%(self.exc, self.msg)
def make_node(self, x):
return gof.Apply(self, [x], [x.type()])
def perform(self, node, inputs, out_storage):
raise self.exc(self.msg)
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# #
import unittest import unittest
import numpy import numpy
import theano
from theano import gof from theano import gof
from theano.gradient import * from theano.gradient import *
...@@ -250,6 +251,15 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -250,6 +251,15 @@ class test_grad_sources_inputs(unittest.TestCase):
self.assertTrue(g[a1.inputs[0]] == 6) self.assertTrue(g[a1.inputs[0]] == 6)
self.assertTrue(g[a1.inputs[1]] == 11) self.assertTrue(g[a1.inputs[1]] == 11)
def test_unimplemented_grad():
a = theano.tensor.vector()
b = theano.gradient.unimplemented_grad(theano.tensor.add, 1, a)
f = theano.function([a], b)
try:
f([1,2,3])
assert 0
except NotImplementedError:
pass
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论