提交 ee2f8a24 authored 作者: bergstra@tikuanyin's avatar bergstra@tikuanyin

Added optional argument (warn_type) to gradient.grad_sources_inputs to disable gradient type checking

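Added an optional argument (`warn_type`, default True) to `gradient.grad_sources_inputs`; passing `warn_type=False` disables the warning emitted when `op.grad` returns a gradient whose type differs from that of the corresponding input.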
上级 3519e348
...@@ -2,16 +2,21 @@ ...@@ -2,16 +2,21 @@
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import sys
import gof #, gof.variable import gof #, gof.variable
import numpy #for numeric_grad import numpy #for numeric_grad
from gof.python25 import all from gof.python25 import all
import gof.utils import gof.utils
def warning(msg):
    """Write a warning line to standard error.

    The emitted line is the word ``WARNING``, one space, ``msg`` and a
    trailing newline.

    :param msg: the text to report; it is formatted with ``%s``.
    """
    # TODO: replace this with logger.warning when adding logging support.
    # sys.stderr.write is used instead of the ``print >>`` statement so that
    # the function produces the same output on Python 2 and Python 3 (the
    # chevron form raises TypeError when evaluated by Python 3).
    sys.stderr.write('WARNING %s\n' % (msg,))
_msg_retType = 'op.grad(...) returned a non-list' _msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients' _msg_badlen = 'op.grad(...) returned wrong number of gradients'
def grad_sources_inputs(sources, graph_inputs): def grad_sources_inputs(sources, graph_inputs, warn_type=True):
""" """
A gradient source is a pair (``r``, ``g_r``), in which ``r`` is a `Variable`, and ``g_r`` is a A gradient source is a pair (``r``, ``g_r``), in which ``r`` is a `Variable`, and ``g_r`` is a
`Variable` that is a gradient wrt ``r``. `Variable` that is a gradient wrt ``r``.
...@@ -96,8 +101,9 @@ def grad_sources_inputs(sources, graph_inputs): ...@@ -96,8 +101,9 @@ def grad_sources_inputs(sources, graph_inputs):
len(g_inputs), len(g_inputs),
len(node.inputs)) len(node.inputs))
for ii, (r, g_r) in enumerate(zip(node.inputs, g_inputs)): for ii, (r, g_r) in enumerate(zip(node.inputs, g_inputs)):
if g_r and (r.type != g_r.type): if warn_type:
print 'WARNING: %s.grad returned a different type for input %i: %s vs. %s'%(node.op, ii, r.type, g_r.type) if g_r and (getattr(r,'type',0) != getattr(g_r,'type', 1)):
warning('%s.grad returned a different type for input %i: %s vs. %s'%(node.op, ii, r, g_r))
if g_r and len(sources) == 1 and sources[0][0].name and r.name: if g_r and len(sources) == 1 and sources[0][0].name and r.name:
g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name) g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name)
if g_r is not None: if g_r is not None:
......
...@@ -9,6 +9,11 @@ from theano import gof ...@@ -9,6 +9,11 @@ from theano import gof
from theano.gradient import * from theano.gradient import *
from theano import gradient from theano import gradient
def _grad_sources_inputs(*positional):
    """Call ``grad_sources_inputs`` with ``warn_type=False``.

    The tests in this module were written before ``warn_type`` existed; with
    the check enabled it complains throughout for nothing, so the tests go
    through this wrapper to keep it switched off.

    :param positional: positional arguments forwarded unchanged to
        ``grad_sources_inputs``.
    :return: whatever ``grad_sources_inputs`` returns.
    """
    # The keyword is deliberately written before the star-argument: that
    # ordering is accepted by every Python 2 release.
    return grad_sources_inputs(warn_type=False, *positional)
class test_grad_sources_inputs(unittest.TestCase): class test_grad_sources_inputs(unittest.TestCase):
def test_retNone1(self): def test_retNone1(self):
"""Test that it is not ok to return None from op.grad()""" """Test that it is not ok to return None from op.grad()"""
...@@ -21,7 +26,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -21,7 +26,7 @@ class test_grad_sources_inputs(unittest.TestCase):
pass pass
a = retNone().make_node() a = retNone().make_node()
try: try:
grad_sources_inputs([(a.out, 1)], None) _grad_sources_inputs([(a.out, 1)], None)
except ValueError, e: except ValueError, e:
self.failUnless(e[0] is gradient._msg_retType) self.failUnless(e[0] is gradient._msg_retType)
return return
...@@ -36,7 +41,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -36,7 +41,7 @@ class test_grad_sources_inputs(unittest.TestCase):
return [None] return [None]
i = gof.generic() i = gof.generic()
a = retNone().make_node(i) a = retNone().make_node(i)
g = grad_sources_inputs([(a.out, 1)], None) g = _grad_sources_inputs([(a.out, 1)], None)
self.failUnless(not i in g) self.failUnless(not i in g)
def test_wrong_rval_len1(self): def test_wrong_rval_len1(self):
...@@ -51,10 +56,10 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -51,10 +56,10 @@ class test_grad_sources_inputs(unittest.TestCase):
i = gof.generic() i = gof.generic()
j = gof.generic() j = gof.generic()
a1 = retNone().make_node(i) a1 = retNone().make_node(i)
g = grad_sources_inputs([(a1.out, 1)], None) g = _grad_sources_inputs([(a1.out, 1)], None)
a2 = retNone().make_node(i,j) a2 = retNone().make_node(i,j)
try: try:
g = grad_sources_inputs([(a2.out, 1)], None) g = _grad_sources_inputs([(a2.out, 1)], None)
except ValueError, e: except ValueError, e:
self.failUnless(e[0] is gradient._msg_badlen) self.failUnless(e[0] is gradient._msg_badlen)
return return
...@@ -74,7 +79,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -74,7 +79,7 @@ class test_grad_sources_inputs(unittest.TestCase):
i = gof.generic() i = gof.generic()
a1 = retNone(self).make_node(i) a1 = retNone(self).make_node(i)
g = grad_sources_inputs([(a1.out, None)], None) g = _grad_sources_inputs([(a1.out, None)], None)
def test_1in_1out(self): def test_1in_1out(self):
"""Test grad is called correctly for a 1-to-1 op""" """Test grad is called correctly for a 1-to-1 op"""
...@@ -87,7 +92,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -87,7 +92,7 @@ class test_grad_sources_inputs(unittest.TestCase):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gval, return gval,
a1 = O().make_node() a1 = O().make_node()
g = grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval) self.failUnless(g[a1.inputs[0]] is gval)
def test_1in_Nout(self): def test_1in_Nout(self):
...@@ -101,7 +106,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -101,7 +106,7 @@ class test_grad_sources_inputs(unittest.TestCase):
def grad(self, (x, ), (gz1, gz2)): def grad(self, (x, ), (gz1, gz2)):
return gval, return gval,
a1 = O().make_node() a1 = O().make_node()
g = grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval) self.failUnless(g[a1.inputs[0]] is gval)
def test_Nin_1out(self): def test_Nin_1out(self):
"""Test grad is called correctly for a many-to-1 op""" """Test grad is called correctly for a many-to-1 op"""
...@@ -115,7 +120,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -115,7 +120,7 @@ class test_grad_sources_inputs(unittest.TestCase):
def grad(self, (x0,x1), (gz, )): def grad(self, (x0,x1), (gz, )):
return (gval0, gval1) return (gval0, gval1)
a1 = O().make_node() a1 = O().make_node()
g = grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval0) self.failUnless(g[a1.inputs[0]] is gval0)
self.failUnless(g[a1.inputs[1]] is gval1) self.failUnless(g[a1.inputs[1]] is gval1)
def test_Nin_Nout(self): def test_Nin_Nout(self):
...@@ -130,7 +135,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -130,7 +135,7 @@ class test_grad_sources_inputs(unittest.TestCase):
def grad(self, (x0,x1), (gz0,gz1)): def grad(self, (x0,x1), (gz0,gz1)):
return gval0, gval1 return gval0, gval1
a1 = O().make_node() a1 = O().make_node()
g = grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval0) self.failUnless(g[a1.inputs[0]] is gval0)
self.failUnless(g[a1.inputs[1]] is gval1) self.failUnless(g[a1.inputs[1]] is gval1)
def test_some_None_ograds(self): def test_some_None_ograds(self):
...@@ -145,7 +150,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -145,7 +150,7 @@ class test_grad_sources_inputs(unittest.TestCase):
return [1] return [1]
i = gof.generic() i = gof.generic()
a1 = O(self).make_node(i) a1 = O(self).make_node(i)
g = grad_sources_inputs([(a1.outputs[0], 1)], None) g = grad_sources_inputs([(a1.outputs[0], 1)], None, warn_type=False)
self.failUnless(g[i] is 1) self.failUnless(g[i] is 1)
def test_some_None_igrads(self): def test_some_None_igrads(self):
...@@ -167,12 +172,12 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -167,12 +172,12 @@ class test_grad_sources_inputs(unittest.TestCase):
k = gof.generic() k = gof.generic()
a1 = O(self, True).make_node(i,j) a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(a1.outputs[1], k) a2 = O(self, True).make_node(a1.outputs[1], k)
g = grad_sources_inputs([(a2.outputs[0], 1)], None) g = grad_sources_inputs([(a2.outputs[0], 1)], None, warn_type=False)
self.failUnless(g[i] is 1 and j not in g and k not in g) self.failUnless(g[i] is 1 and j not in g and k not in g)
a1 = O(self, True).make_node(i,j) a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(k, a1.outputs[1]) a2 = O(self, True).make_node(k, a1.outputs[1])
g = grad_sources_inputs([(a2.outputs[0], 1)], None) g = _grad_sources_inputs([(a2.outputs[0], 1)], None)
self.failUnless(g[k] is 1 and i not in g and j not in g) self.failUnless(g[k] is 1 and i not in g and j not in g)
def test_inputs(self): def test_inputs(self):
...@@ -197,7 +202,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -197,7 +202,7 @@ class test_grad_sources_inputs(unittest.TestCase):
k = gof.generic() k = gof.generic()
a1 = O(self, True).make_node(i,j) a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(k,a1.outputs[1]) a2 = O(self, True).make_node(k,a1.outputs[1])
g = grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4), g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
(a1.outputs[0], 3), (a1.outputs[0], 3)], a1.outputs) (a1.outputs[0], 3), (a1.outputs[0], 3)], a1.outputs)
self.failUnless(g[a2.inputs[0]] == 1) self.failUnless(g[a2.inputs[0]] == 1)
self.failUnless(g[a2.inputs[1]] == 5) self.failUnless(g[a2.inputs[1]] == 5)
...@@ -228,7 +233,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -228,7 +233,7 @@ class test_grad_sources_inputs(unittest.TestCase):
k = gof.generic() k = gof.generic()
a1 = O(self,True).make_node(i,j) a1 = O(self,True).make_node(i,j)
a2 = O(self,True).make_node(k,a1.outputs[1]) a2 = O(self,True).make_node(k,a1.outputs[1])
g = grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4), g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
(a1.outputs[0], 3), (a1.outputs[0], 3)], None) (a1.outputs[0], 3), (a1.outputs[0], 3)], None)
self.failUnless(g[a2.inputs[0]] == 1) self.failUnless(g[a2.inputs[0]] == 1)
self.failUnless(g[a2.inputs[1]] == 5) self.failUnless(g[a2.inputs[1]] == 5)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论