提交 27b2330b authored 作者: Ian Goodfellow's avatar Ian Goodfellow

re-enabled grad_sources_inputs tests

made grad_sources_inputs accept None for inputs
上级 46c420be
...@@ -607,6 +607,11 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'): ...@@ -607,6 +607,11 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
outputs, output_grads = zip(*sources) outputs, output_grads = zip(*sources)
if graph_inputs is None:
graph_inputs = gof.graph.inputs(outputs)
wrt = graph_inputs wrt = graph_inputs
......
...@@ -6,22 +6,17 @@ import unittest ...@@ -6,22 +6,17 @@ import unittest
import theano import theano
from theano import gof from theano import gof
#from theano.gradient import grad_sources_inputs from theano.gradient import grad_sources_inputs
from theano import gradient from theano import gradient
from theano.tensor.nnet.Conv3D import conv3D from theano.tensor.nnet.Conv3D import conv3D
from theano import config from theano import config
#def _grad_sources_inputs(*args): def _grad_sources_inputs(*args):
# warn_type was introduced after this code, it complains throughout for nothing. # warn_type was introduced after this code, it complains throughout for nothing.
# return grad_sources_inputs(warn_type=False, *args) return grad_sources_inputs(warn_type=False, *args)
if 0: class test_grad_sources_inputs(unittest.TestCase):
#most of these tests are no longer relevant now that grad_sources_inputs is gone
#also, some of our policies about what is allowed or not have changed
#nonetheless, it may be a good idea to resurrect some of these tests and write
#them in terms of tensor.grad directly
class test_grad_sources_inputs(unittest.TestCase):
def test_retNone1(self): def test_retNone1(self):
"""Test that it is not ok to return None from op.grad()""" """Test that it is not ok to return None from op.grad()"""
class retNone(gof.op.Op): class retNone(gof.op.Op):
...@@ -119,6 +114,7 @@ if 0: ...@@ -119,6 +114,7 @@ if 0:
a1 = O().make_node() a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval) self.assertTrue(g[a1.inputs[0]] is gval)
def test_Nin_1out(self): def test_Nin_1out(self):
"""Test grad is called correctly for a many-to-1 op""" """Test grad is called correctly for a many-to-1 op"""
gval0 = gof.generic() gval0 = gof.generic()
...@@ -136,6 +132,7 @@ if 0: ...@@ -136,6 +132,7 @@ if 0:
g = _grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval0) self.assertTrue(g[a1.inputs[0]] is gval0)
self.assertTrue(g[a1.inputs[1]] is gval1) self.assertTrue(g[a1.inputs[1]] is gval1)
def test_Nin_Nout(self): def test_Nin_Nout(self):
"""Test grad is called correctly for a many-to-many op""" """Test grad is called correctly for a many-to-many op"""
gval0 = gof.generic() gval0 = gof.generic()
...@@ -151,6 +148,7 @@ if 0: ...@@ -151,6 +148,7 @@ if 0:
g = _grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval0) self.assertTrue(g[a1.inputs[0]] is gval0)
self.assertTrue(g[a1.inputs[1]] is gval1) self.assertTrue(g[a1.inputs[1]] is gval1)
def test_some_None_ograds(self): def test_some_None_ograds(self):
"""Test grad is called when some output gradients are None""" """Test grad is called when some output gradients are None"""
class O(gof.op.Op): class O(gof.op.Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论