提交 ee2f8a24 authored 作者: bergstra@tikuanyin's avatar bergstra@tikuanyin

added optional argument (warn_type) to gradient.grad_sources_inputs to disable…

added optional argument (warn_type) to gradient.grad_sources_inputs to disable gradient type checking
上级 3519e348
......@@ -2,16 +2,21 @@
__docformat__ = "restructuredtext en"
import sys
import gof #, gof.variable
import numpy #for numeric_grad
from gof.python25 import all
import gof.utils
def warning(msg):
# replace this with logger.warning when adding logging support
print >> sys.stderr, 'WARNING', msg
_msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients'
def grad_sources_inputs(sources, graph_inputs):
def grad_sources_inputs(sources, graph_inputs, warn_type=True):
"""
A gradient source is a pair (``r``, ``g_r``), in which ``r`` is a `Variable`, and ``g_r`` is a
`Variable` that is a gradient wrt ``r``.
......@@ -96,8 +101,9 @@ def grad_sources_inputs(sources, graph_inputs):
len(g_inputs),
len(node.inputs))
for ii, (r, g_r) in enumerate(zip(node.inputs, g_inputs)):
if g_r and (r.type != g_r.type):
print 'WARNING: %s.grad returned a different type for input %i: %s vs. %s'%(node.op, ii, r.type, g_r.type)
if warn_type:
if g_r and (getattr(r,'type',0) != getattr(g_r,'type', 1)):
warning('%s.grad returned a different type for input %i: %s vs. %s'%(node.op, ii, r, g_r))
if g_r and len(sources) == 1 and sources[0][0].name and r.name:
g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name)
if g_r is not None:
......
......@@ -9,6 +9,11 @@ from theano import gof
from theano.gradient import *
from theano import gradient
def _grad_sources_inputs(*args):
# warn_type was introduced after this code, it complains throughout for nothing.
return grad_sources_inputs(warn_type=False, *args)
class test_grad_sources_inputs(unittest.TestCase):
def test_retNone1(self):
"""Test that it is not ok to return None from op.grad()"""
......@@ -21,7 +26,7 @@ class test_grad_sources_inputs(unittest.TestCase):
pass
a = retNone().make_node()
try:
grad_sources_inputs([(a.out, 1)], None)
_grad_sources_inputs([(a.out, 1)], None)
except ValueError, e:
self.failUnless(e[0] is gradient._msg_retType)
return
......@@ -36,7 +41,7 @@ class test_grad_sources_inputs(unittest.TestCase):
return [None]
i = gof.generic()
a = retNone().make_node(i)
g = grad_sources_inputs([(a.out, 1)], None)
g = _grad_sources_inputs([(a.out, 1)], None)
self.failUnless(not i in g)
def test_wrong_rval_len1(self):
......@@ -51,10 +56,10 @@ class test_grad_sources_inputs(unittest.TestCase):
i = gof.generic()
j = gof.generic()
a1 = retNone().make_node(i)
g = grad_sources_inputs([(a1.out, 1)], None)
g = _grad_sources_inputs([(a1.out, 1)], None)
a2 = retNone().make_node(i,j)
try:
g = grad_sources_inputs([(a2.out, 1)], None)
g = _grad_sources_inputs([(a2.out, 1)], None)
except ValueError, e:
self.failUnless(e[0] is gradient._msg_badlen)
return
......@@ -74,7 +79,7 @@ class test_grad_sources_inputs(unittest.TestCase):
i = gof.generic()
a1 = retNone(self).make_node(i)
g = grad_sources_inputs([(a1.out, None)], None)
g = _grad_sources_inputs([(a1.out, None)], None)
def test_1in_1out(self):
"""Test grad is called correctly for a 1-to-1 op"""
......@@ -87,7 +92,7 @@ class test_grad_sources_inputs(unittest.TestCase):
def grad(self, (x, ), (gz, )):
return gval,
a1 = O().make_node()
g = grad_sources_inputs([(a1.outputs[0], 1)], None)
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval)
def test_1in_Nout(self):
......@@ -101,7 +106,7 @@ class test_grad_sources_inputs(unittest.TestCase):
def grad(self, (x, ), (gz1, gz2)):
return gval,
a1 = O().make_node()
g = grad_sources_inputs([(a1.outputs[0], 1)], None)
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval)
def test_Nin_1out(self):
"""Test grad is called correctly for a many-to-1 op"""
......@@ -115,7 +120,7 @@ class test_grad_sources_inputs(unittest.TestCase):
def grad(self, (x0,x1), (gz, )):
return (gval0, gval1)
a1 = O().make_node()
g = grad_sources_inputs([(a1.outputs[0], 1)], None)
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval0)
self.failUnless(g[a1.inputs[1]] is gval1)
def test_Nin_Nout(self):
......@@ -130,7 +135,7 @@ class test_grad_sources_inputs(unittest.TestCase):
def grad(self, (x0,x1), (gz0,gz1)):
return gval0, gval1
a1 = O().make_node()
g = grad_sources_inputs([(a1.outputs[0], 1)], None)
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval0)
self.failUnless(g[a1.inputs[1]] is gval1)
def test_some_None_ograds(self):
......@@ -145,7 +150,7 @@ class test_grad_sources_inputs(unittest.TestCase):
return [1]
i = gof.generic()
a1 = O(self).make_node(i)
g = grad_sources_inputs([(a1.outputs[0], 1)], None)
g = grad_sources_inputs([(a1.outputs[0], 1)], None, warn_type=False)
self.failUnless(g[i] is 1)
def test_some_None_igrads(self):
......@@ -167,12 +172,12 @@ class test_grad_sources_inputs(unittest.TestCase):
k = gof.generic()
a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(a1.outputs[1], k)
g = grad_sources_inputs([(a2.outputs[0], 1)], None)
g = grad_sources_inputs([(a2.outputs[0], 1)], None, warn_type=False)
self.failUnless(g[i] is 1 and j not in g and k not in g)
a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(k, a1.outputs[1])
g = grad_sources_inputs([(a2.outputs[0], 1)], None)
g = _grad_sources_inputs([(a2.outputs[0], 1)], None)
self.failUnless(g[k] is 1 and i not in g and j not in g)
def test_inputs(self):
......@@ -197,7 +202,7 @@ class test_grad_sources_inputs(unittest.TestCase):
k = gof.generic()
a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(k,a1.outputs[1])
g = grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
(a1.outputs[0], 3), (a1.outputs[0], 3)], a1.outputs)
self.failUnless(g[a2.inputs[0]] == 1)
self.failUnless(g[a2.inputs[1]] == 5)
......@@ -228,7 +233,7 @@ class test_grad_sources_inputs(unittest.TestCase):
k = gof.generic()
a1 = O(self,True).make_node(i,j)
a2 = O(self,True).make_node(k,a1.outputs[1])
g = grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
(a1.outputs[0], 3), (a1.outputs[0], 3)], None)
self.failUnless(g[a2.inputs[0]] == 1)
self.failUnless(g[a2.inputs[1]] == 5)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论