Commit 71a8d3aa authored by Ian Goodfellow

added test that the upgrade of connection_pattern works correctly

Parent 20a1aee0
@@ -11,6 +11,8 @@ from theano import gradient
from theano.tensor.nnet.Conv3D import conv3D
from theano import config
import numpy as np
from theano.gradient import DisconnectedType
from theano.gof.null_type import NullType
one = theano.tensor.as_tensor_variable(1.)
@@ -316,5 +318,47 @@ def test_grad_disconnected():
    assert np.allclose(g,np.ones(x.shape,dtype=x.dtype))
def test_disconnected_nan():
    # test that connection_pattern can prevent getting NaN

    # Op1 has two outputs, f and g
    # x is connected to f but not to g
    class Op1(theano.gof.Op):
        def make_node(self, x):
            return theano.Apply(self, inputs=[x],
                                outputs=[x.type(), theano.tensor.scalar()])

        def connection_pattern(self, node):
            return [[True, False]]

        def grad(self, inputs, output_grads):
            return [inputs[0].zeros_like()]

    # Op2 has two inputs, f and g
    # Its gradient with respect to g is not defined
    class Op2(theano.gof.Op):
        def make_node(self, f, g):
            return theano.Apply(self, inputs=[f, g],
                                outputs=[theano.tensor.scalar()])

        def grad(self, inputs, output_grads):
            return [inputs[0].zeros_like(), NullType()()]

    x = theano.tensor.vector()
    f, g = Op1()(x)
    cost = Op2()(f, g)

    # cost is differentiable wrt x,
    # but we can't tell that without using Op1's connection pattern:
    # looking at the theano graph alone, g is an ancestor of cost
    # and has x as an ancestor, so we would try to compute its gradient
    g = gradient.grad(cost, x)

    # If we made it to here without an exception, then the
    # connection_pattern functionality worked correctly
if __name__ == '__main__':
    unittest.main()
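For context (not part of this commit), the sketch below shows the same mechanism from an Op author's side: connection_pattern declares which inputs each output depends on, and grad returns DisconnectedType for the unused input. The op name AddFirstOnly and its implementation are hypothetical; theano.gof.Op, theano.Apply, connection_pattern, DisconnectedType, and theano.gradient.grad(..., disconnected_inputs='ignore') are the standard Theano APIs this sketch assumes.

import theano
import theano.tensor as T
from theano.gradient import DisconnectedType


class AddFirstOnly(theano.gof.Op):
    """Hypothetical op: output = x + 1; y is accepted but never used."""

    def make_node(self, x, y):
        return theano.Apply(self, inputs=[x, y], outputs=[x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0] + 1

    def connection_pattern(self, node):
        # One row per input, one column per output:
        # x -> output is connected, y -> output is not.
        return [[True], [False]]

    def grad(self, inputs, output_grads):
        # d(output)/dx is 1, so the gradient wrt x is the output gradient;
        # y is disconnected, so return a DisconnectedType instance for it.
        return [output_grads[0], DisconnectedType()()]


x = T.scalar()
y = T.scalar()
cost = AddFirstOnly()(x, y)

# Differentiating wrt x works; differentiating wrt y raises a
# DisconnectedInputError unless disconnected_inputs='ignore' is passed.
gx = theano.gradient.grad(cost, x)
gy = theano.gradient.grad(cost, y, disconnected_inputs='ignore')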