提交 8b465d41 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

removed test enforcing specific traversal behavior when an op returns

None
上级 0bdb4850
...@@ -131,38 +131,11 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -131,38 +131,11 @@ class test_grad_sources_inputs(unittest.TestCase):
outputs = [theano.tensor.matrix(),theano.tensor.matrix()] outputs = [theano.tensor.matrix(),theano.tensor.matrix()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inputs, g_out): def grad(self, inputs, g_out):
return [1] return [one]
i = theano.tensor.matrix() i = theano.tensor.matrix()
a1 = O(self).make_node(i) a1 = O(self).make_node(i)
g = grad_sources_inputs([(a1.outputs[0], one)], None, warn_type=False) g = grad_sources_inputs([(a1.outputs[0], one)], None, warn_type=False)
self.assertTrue(g[i] is 1) self.assertTrue(g[i] is one)
def test_some_None_igrads(self):
"""Test that traversal works properly when an op return some None"""
class O(gof.op.Op):
def __init__(self, tst, grad_ok):
self.tst = tst
self.grad_ok = grad_ok
def make_node(self, *inputs):
outputs = [theano.tensor.matrix(),theano.tensor.matrix()]
return gof.Apply(self, inputs, outputs)
def grad(self, inputs, g_out):
if not self.grad_ok:
self.tst.fail()
else:
return [1, None]
i = theano.tensor.matrix()
j = theano.tensor.matrix()
k = theano.tensor.matrix()
a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(a1.outputs[1], k)
g = grad_sources_inputs([(a2.outputs[0], one)], None, warn_type=False)
self.assertTrue(g[i] is 1 and j not in g and k not in g)
a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(k, a1.outputs[1])
g = _grad_sources_inputs([(a2.outputs[0], one)], None)
self.assertTrue(g[k] is 1 and i not in g and j not in g)
def test_inputs(self): def test_inputs(self):
"""Test that passing inputs shortens the traversal""" """Test that passing inputs shortens the traversal"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论