removed type mismatch warning from gradient propagation algorithm. It was not a good idea. It's OK for a gradient to have a different type from the original thing.
Parent commit: 0bfb7f8c
...@@ -9,14 +9,27 @@ import numpy #for numeric_grad ...@@ -9,14 +9,27 @@ import numpy #for numeric_grad
from gof.python25 import all from gof.python25 import all
import gof.utils import gof.utils
def warning(msg): import logging
# replace this with logger.warning when adding logging support _logger=logging.getLogger("theano.gradient")
print >> sys.stderr, 'WARNING', msg _logger.setLevel(logging.WARN)
def error(*args):
    """Log all *args*, space-joined, at ERROR level on the module logger."""
    # Lazy %-style args are the logging-module idiom (the substitution is
    # deferred until a handler actually emits the record).
    _logger.error("ERROR: %s", ' '.join(str(a) for a in args))
def warning(*args):
    """Log all *args*, space-joined, at WARNING level on the module logger."""
    # Lazy %-style args are the logging-module idiom (the substitution is
    # deferred until a handler actually emits the record).
    _logger.warning("WARNING: %s", ' '.join(str(a) for a in args))
def info(*args):
    """Log all *args*, space-joined, at INFO level on the module logger."""
    # Lazy %-style args are the logging-module idiom (the substitution is
    # deferred until a handler actually emits the record).
    _logger.info("INFO: %s", ' '.join(str(a) for a in args))
def debug(*args):
    """Log all *args*, space-joined, at DEBUG level on the module logger."""
    # Lazy %-style args are the logging-module idiom (the substitution is
    # deferred until a handler actually emits the record).
    _logger.debug("DEBUG: %s", ' '.join(str(a) for a in args))
_msg_retType = 'op.grad(...) returned a non-list' _msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients' _msg_badlen = 'op.grad(...) returned wrong number of gradients'
def grad_sources_inputs(sources, graph_inputs, warn_type=True): def grad_sources_inputs(sources, graph_inputs):
""" """
A gradient source is a pair (``r``, ``g_r``), in which ``r`` is a `Variable`, and ``g_r`` is a A gradient source is a pair (``r``, ``g_r``), in which ``r`` is a `Variable`, and ``g_r`` is a
`Variable` that is a gradient wrt ``r``. `Variable` that is a gradient wrt ``r``.
...@@ -101,9 +114,6 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -101,9 +114,6 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
len(g_inputs), len(g_inputs),
len(node.inputs)) len(node.inputs))
for ii, (r, g_r) in enumerate(zip(node.inputs, g_inputs)): for ii, (r, g_r) in enumerate(zip(node.inputs, g_inputs)):
if warn_type:
if g_r and (getattr(r,'type',0) != getattr(g_r,'type', 1)):
warning('%s.grad returned a different type for input %i: %s vs. %s'%(node.op, ii, r, g_r))
if g_r and len(sources) == 1 and sources[0][0].name and r.name: if g_r and len(sources) == 1 and sources[0][0].name and r.name:
g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name) g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name)
if g_r is not None: if g_r is not None:
......
...@@ -10,9 +10,7 @@ from theano.gradient import * ...@@ -10,9 +10,7 @@ from theano.gradient import *
from theano import gradient from theano import gradient
def _grad_sources_inputs(*args): _grad_sources_inputs = grad_sources_inputs
# warn_type was introduced after this code, it complains throughout for nothing.
return grad_sources_inputs(warn_type=False, *args)
class test_grad_sources_inputs(unittest.TestCase): class test_grad_sources_inputs(unittest.TestCase):
def test_retNone1(self): def test_retNone1(self):
...@@ -150,7 +148,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -150,7 +148,7 @@ class test_grad_sources_inputs(unittest.TestCase):
return [1] return [1]
i = gof.generic() i = gof.generic()
a1 = O(self).make_node(i) a1 = O(self).make_node(i)
g = grad_sources_inputs([(a1.outputs[0], 1)], None, warn_type=False) g = grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[i] is 1) self.failUnless(g[i] is 1)
def test_some_None_igrads(self): def test_some_None_igrads(self):
...@@ -172,7 +170,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -172,7 +170,7 @@ class test_grad_sources_inputs(unittest.TestCase):
k = gof.generic() k = gof.generic()
a1 = O(self, True).make_node(i,j) a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(a1.outputs[1], k) a2 = O(self, True).make_node(a1.outputs[1], k)
g = grad_sources_inputs([(a2.outputs[0], 1)], None, warn_type=False) g = grad_sources_inputs([(a2.outputs[0], 1)], None)
self.failUnless(g[i] is 1 and j not in g and k not in g) self.failUnless(g[i] is 1 and j not in g and k not in g)
a1 = O(self, True).make_node(i,j) a1 = O(self, True).make_node(i,j)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论