提交 27b2330b authored 作者: Ian Goodfellow's avatar Ian Goodfellow

re-enabled grad_sources_inputs tests

made grad_sources_inputs accept None for inputs
上级 46c420be
...@@ -607,6 +607,11 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'): ...@@ -607,6 +607,11 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
outputs, output_grads = zip(*sources) outputs, output_grads = zip(*sources)
if graph_inputs is None:
graph_inputs = gof.graph.inputs(outputs)
wrt = graph_inputs wrt = graph_inputs
......
...@@ -6,256 +6,254 @@ import unittest ...@@ -6,256 +6,254 @@ import unittest
import theano import theano
from theano import gof from theano import gof
#from theano.gradient import grad_sources_inputs from theano.gradient import grad_sources_inputs
from theano import gradient from theano import gradient
from theano.tensor.nnet.Conv3D import conv3D from theano.tensor.nnet.Conv3D import conv3D
from theano import config from theano import config
#def _grad_sources_inputs(*args): def _grad_sources_inputs(*args):
# warn_type was introduced after this code, it complains throughout for nothing. # warn_type was introduced after this code, it complains throughout for nothing.
# return grad_sources_inputs(warn_type=False, *args) return grad_sources_inputs(warn_type=False, *args)
if 0: class test_grad_sources_inputs(unittest.TestCase):
#most of these tests are no longer relevant now that grad_sources_inputs is gone def test_retNone1(self):
#also, some of our policies about what is allowed or not have changed """Test that it is not ok to return None from op.grad()"""
#nonetheless, it may be a good idea to resurrect some of these tests and write class retNone(gof.op.Op):
#them in terms of tensor.grad directly def make_node(self):
class test_grad_sources_inputs(unittest.TestCase): inputs = [gof.generic()]
def test_retNone1(self): outputs = [gof.generic()]
"""Test that it is not ok to return None from op.grad()""" return gof.Apply(self, inputs, outputs)
class retNone(gof.op.Op): def grad(self, inp, grads):
def make_node(self): x, = inp
inputs = [gof.generic()] gz, = grads
outputs = [gof.generic()] pass
return gof.Apply(self, inputs, outputs) a = retNone().make_node()
def grad(self, inp, grads): try:
x, = inp _grad_sources_inputs([(a.out, 1)], None)
gz, = grads except ValueError, e:
pass self.assertTrue(e[0] is gradient._msg_retType)
a = retNone().make_node() return
try: self.fail()
_grad_sources_inputs([(a.out, 1)], None) def test_retNone1_b(self):
except ValueError, e: """Test that it is ok to return [None] from op.grad()"""
self.assertTrue(e[0] is gradient._msg_retType) class retNone(gof.op.Op):
return def make_node(self, *inputs):
self.fail() outputs = [gof.generic()]
def test_retNone1_b(self): return gof.Apply(self, inputs, outputs)
"""Test that it is ok to return [None] from op.grad()""" def grad(self, inp, grads):
class retNone(gof.op.Op): return [None]
def make_node(self, *inputs): i = gof.generic()
outputs = [gof.generic()] a = retNone().make_node(i)
return gof.Apply(self, inputs, outputs) g = _grad_sources_inputs([(a.out, 1)], None)
def grad(self, inp, grads): self.assertTrue(not i in g)
return [None]
i = gof.generic() def test_wrong_rval_len1(self):
a = retNone().make_node(i) """Test that it is not ok to return the wrong number of gradients"""
g = _grad_sources_inputs([(a.out, 1)], None) class retNone(gof.op.Op):
self.assertTrue(not i in g) def make_node(self, *inputs):
outputs = [gof.generic()]
def test_wrong_rval_len1(self): return gof.Apply(self, inputs, outputs)
"""Test that it is not ok to return the wrong number of gradients""" def grad(self, inputs, grads):
class retNone(gof.op.Op): return [None]
def make_node(self, *inputs):
outputs = [gof.generic()] i = gof.generic()
return gof.Apply(self, inputs, outputs) j = gof.generic()
def grad(self, inputs, grads): a1 = retNone().make_node(i)
return [None] g = _grad_sources_inputs([(a1.out, 1)], None)
a2 = retNone().make_node(i,j)
i = gof.generic() try:
j = gof.generic() g = _grad_sources_inputs([(a2.out, 1)], None)
a1 = retNone().make_node(i) except ValueError, e:
g = _grad_sources_inputs([(a1.out, 1)], None) self.assertTrue(e[0] is gradient._msg_badlen)
a2 = retNone().make_node(i,j) return
try: self.fail()
g = _grad_sources_inputs([(a2.out, 1)], None)
except ValueError, e:
self.assertTrue(e[0] is gradient._msg_badlen) def test_stop_on_all_none(self):
return """Test that op.grad() is not called when output grads are all None"""
self.fail() class retNone(gof.op.Op):
def __init__(self, tst):
self.tst = tst
def test_stop_on_all_none(self): def make_node(self, *inputs):
"""Test that op.grad() is not called when output grads are all None""" outputs = [gof.generic()]
class retNone(gof.op.Op): return gof.Apply(self, inputs, outputs)
def __init__(self, tst): def grad(self, inputs, grads):
self.tst = tst self.tst.fail()
def make_node(self, *inputs):
outputs = [gof.generic()] i = gof.generic()
return gof.Apply(self, inputs, outputs) a1 = retNone(self).make_node(i)
def grad(self, inputs, grads): g = _grad_sources_inputs([(a1.out, None)], None)
def test_1in_1out(self):
"""Test grad is called correctly for a 1-to-1 op"""
gval = gof.generic()
class O(gof.op.Op):
def make_node(self):
inputs = [gof.generic()]
outputs = [gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads):
return gval,
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval)
def test_1in_Nout(self):
"""Test grad is called correctly for a 1-to-many op"""
gval = gof.generic()
class O(gof.op.Op):
def make_node(self):
inputs = [gof.generic()]
outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads):
x, = inp
gz1, gz2 = grads
return gval,
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval)
def test_Nin_1out(self):
"""Test grad is called correctly for a many-to-1 op"""
gval0 = gof.generic()
gval1 = gof.generic()
class O(gof.op.Op):
def make_node(self):
inputs = [gof.generic(),gof.generic()]
outputs = [gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads):
x0, x1 = inp
gz, = grads
return (gval0, gval1)
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval0)
self.assertTrue(g[a1.inputs[1]] is gval1)
def test_Nin_Nout(self):
"""Test grad is called correctly for a many-to-many op"""
gval0 = gof.generic()
gval1 = gof.generic()
class O(gof.op.Op):
def make_node(self):
inputs = [gof.generic(),gof.generic()]
outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads):
return gval0, gval1
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval0)
self.assertTrue(g[a1.inputs[1]] is gval1)
def test_some_None_ograds(self):
"""Test grad is called when some output gradients are None"""
class O(gof.op.Op):
def __init__(self, tst):
self.tst = tst
def make_node(self, *inputs):
outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inputs, g_out):
return [1]
i = gof.generic()
a1 = O(self).make_node(i)
g = grad_sources_inputs([(a1.outputs[0], 1)], None, warn_type=False)
self.assertTrue(g[i] is 1)
def test_some_None_igrads(self):
"""Test that traversal works properly when an op return some None"""
class O(gof.op.Op):
def __init__(self, tst, grad_ok):
self.tst = tst
self.grad_ok = grad_ok
def make_node(self, *inputs):
outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inputs, g_out):
if not self.grad_ok:
self.tst.fail() self.tst.fail()
else:
return [1, None]
i = gof.generic()
j = gof.generic()
k = gof.generic()
a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(a1.outputs[1], k)
g = grad_sources_inputs([(a2.outputs[0], 1)], None, warn_type=False)
self.assertTrue(g[i] is 1 and j not in g and k not in g)
i = gof.generic() a1 = O(self, True).make_node(i,j)
a1 = retNone(self).make_node(i) a2 = O(self, True).make_node(k, a1.outputs[1])
g = _grad_sources_inputs([(a1.out, None)], None) g = _grad_sources_inputs([(a2.outputs[0], 1)], None)
self.assertTrue(g[k] is 1 and i not in g and j not in g)
def test_1in_1out(self):
"""Test grad is called correctly for a 1-to-1 op""" def test_inputs(self):
gval = gof.generic() """Test that passing inputs shortens the traversal"""
class O(gof.op.Op): class O(gof.op.Op):
def make_node(self): def __init__(self, tst, grad_ok):
inputs = [gof.generic()] self.tst = tst
outputs = [gof.generic()] self.grad_ok = grad_ok
return gof.Apply(self, inputs, outputs) def make_node(self, *inputs):
def grad(self, inp, grads): outputs = [gof.generic(),gof.generic()]
return gval, return gof.Apply(self, inputs, outputs)
a1 = O().make_node() def grad(self, inputs, grads):
g = _grad_sources_inputs([(a1.outputs[0], 1)], None) g0, g1 = grads
self.assertTrue(g[a1.inputs[0]] is gval) if not self.grad_ok:
self.tst.fail()
def test_1in_Nout(self): else:
"""Test grad is called correctly for a 1-to-many op""" if g1:
gval = gof.generic() return [g0, g0+g1]
class O(gof.op.Op):
def make_node(self):
inputs = [gof.generic()]
outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads):
x, = inp
gz1, gz2 = grads
return gval,
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval)
def test_Nin_1out(self):
"""Test grad is called correctly for a many-to-1 op"""
gval0 = gof.generic()
gval1 = gof.generic()
class O(gof.op.Op):
def make_node(self):
inputs = [gof.generic(),gof.generic()]
outputs = [gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads):
x0, x1 = inp
gz, = grads
return (gval0, gval1)
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval0)
self.assertTrue(g[a1.inputs[1]] is gval1)
def test_Nin_Nout(self):
"""Test grad is called correctly for a many-to-many op"""
gval0 = gof.generic()
gval1 = gof.generic()
class O(gof.op.Op):
def make_node(self):
inputs = [gof.generic(),gof.generic()]
outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads):
return gval0, gval1
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval0)
self.assertTrue(g[a1.inputs[1]] is gval1)
def test_some_None_ograds(self):
"""Test grad is called when some output gradients are None"""
class O(gof.op.Op):
def __init__(self, tst):
self.tst = tst
def make_node(self, *inputs):
outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inputs, g_out):
return [1]
i = gof.generic()
a1 = O(self).make_node(i)
g = grad_sources_inputs([(a1.outputs[0], 1)], None, warn_type=False)
self.assertTrue(g[i] is 1)
def test_some_None_igrads(self):
"""Test that traversal works properly when an op return some None"""
class O(gof.op.Op):
def __init__(self, tst, grad_ok):
self.tst = tst
self.grad_ok = grad_ok
def make_node(self, *inputs):
outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inputs, g_out):
if not self.grad_ok:
self.tst.fail()
else:
return [1, None]
i = gof.generic()
j = gof.generic()
k = gof.generic()
a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(a1.outputs[1], k)
g = grad_sources_inputs([(a2.outputs[0], 1)], None, warn_type=False)
self.assertTrue(g[i] is 1 and j not in g and k not in g)
a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(k, a1.outputs[1])
g = _grad_sources_inputs([(a2.outputs[0], 1)], None)
self.assertTrue(g[k] is 1 and i not in g and j not in g)
def test_inputs(self):
"""Test that passing inputs shortens the traversal"""
class O(gof.op.Op):
def __init__(self, tst, grad_ok):
self.tst = tst
self.grad_ok = grad_ok
def make_node(self, *inputs):
outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs)
def grad(self, inputs, grads):
g0, g1 = grads
if not self.grad_ok:
self.tst.fail()
else: else:
if g1: return [g0, g0]
return [g0, g0+g1] i = gof.generic()
else: j = gof.generic()
return [g0, g0] k = gof.generic()
i = gof.generic() a1 = O(self, True).make_node(i,j)
j = gof.generic() a2 = O(self, True).make_node(k,a1.outputs[1])
k = gof.generic() g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
a1 = O(self, True).make_node(i,j) (a1.outputs[0], 3), (a1.outputs[0], 3)], a1.outputs)
a2 = O(self, True).make_node(k,a1.outputs[1]) self.assertTrue(g[a2.inputs[0]] == 1)
g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4), self.assertTrue(g[a2.inputs[1]] == 5)
(a1.outputs[0], 3), (a1.outputs[0], 3)], a1.outputs) self.assertTrue(g[a1.outputs[0]] == 6)
self.assertTrue(g[a2.inputs[0]] == 1) self.assertTrue(g[a1.outputs[1]] == 5)
self.assertTrue(g[a2.inputs[1]] == 5) self.assertTrue(a1.inputs[0] not in g)
self.assertTrue(g[a1.outputs[0]] == 6) self.assertTrue(a1.inputs[1] not in g)
self.assertTrue(g[a1.outputs[1]] == 5)
self.assertTrue(a1.inputs[0] not in g) def test_multiple_sources(self):
self.assertTrue(a1.inputs[1] not in g) """Test that passing multiple sources works"""
class O(gof.op.Op):
def test_multiple_sources(self): def __init__(self, tst, grad_ok):
"""Test that passing multiple sources works""" self.tst = tst
class O(gof.op.Op): self.grad_ok = grad_ok
def __init__(self, tst, grad_ok): def make_node(self, *inputs):
self.tst = tst outputs = [gof.generic(),gof.generic()]
self.grad_ok = grad_ok return gof.Apply(self, inputs, outputs)
def make_node(self, *inputs): def grad(self, inputs, grads):
outputs = [gof.generic(),gof.generic()] g0, g1 = grads
return gof.Apply(self, inputs, outputs) if not self.grad_ok:
def grad(self, inputs, grads): self.tst.fail()
g0, g1 = grads else:
if not self.grad_ok: if g1:
self.tst.fail() return [g0, g0+g1]
else: else:
if g1: return [g0, g0]
return [g0, g0+g1] i = gof.generic()
else: j = gof.generic()
return [g0, g0] k = gof.generic()
i = gof.generic() a1 = O(self,True).make_node(i,j)
j = gof.generic() a2 = O(self,True).make_node(k,a1.outputs[1])
k = gof.generic() g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
a1 = O(self,True).make_node(i,j) (a1.outputs[0], 3), (a1.outputs[0], 3)], None)
a2 = O(self,True).make_node(k,a1.outputs[1]) self.assertTrue(g[a2.inputs[0]] == 1)
g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4), self.assertTrue(g[a2.inputs[1]] == 5)
(a1.outputs[0], 3), (a1.outputs[0], 3)], None) self.assertTrue(g[a1.outputs[0]] == 6)
self.assertTrue(g[a2.inputs[0]] == 1) self.assertTrue(g[a1.outputs[1]] == 5)
self.assertTrue(g[a2.inputs[1]] == 5) self.assertTrue(g[a1.inputs[0]] == 6)
self.assertTrue(g[a1.outputs[0]] == 6) self.assertTrue(g[a1.inputs[1]] == 11)
self.assertTrue(g[a1.outputs[1]] == 5)
self.assertTrue(g[a1.inputs[0]] == 6)
self.assertTrue(g[a1.inputs[1]] == 11)
def test_unimplemented_grad_func(): def test_unimplemented_grad_func():
#tests that function compilation catches unimplemented grads in the graph #tests that function compilation catches unimplemented grads in the graph
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论