提交 71a8d3aa authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added test that the upgrade of connection_pattern works correctly

上级 20a1aee0
......@@ -11,6 +11,8 @@ from theano import gradient
from theano.tensor.nnet.Conv3D import conv3D
from theano import config
import numpy as np
from theano.gradient import DisconnectedType
from theano.gof.null_type import NullType
one = theano.tensor.as_tensor_variable(1.)
......@@ -316,5 +318,47 @@ def test_grad_disconnected():
assert np.allclose(g,np.ones(x.shape,dtype=x.dtype))
def test_disconnected_nan():
# test that connection_pattern can prevent getting NaN
# Op1 has two outputs, f and g
# x is connected to f but not to g
class Op1(theano.gof.Op):
def make_node(self, x):
return theano.Apply(self, inputs=[x],
outputs = [ x.type(), theano.tensor.scalar() ])
def connection_pattern(self):
return [[True, False]]
def grad(self, inputs, output_grads):
return [ inputs[0].zeros_like() ]
# Op2 has two inputs, f and g
# Its gradient with respect to g is not defined
class Op2(theano.gof.Op):
def make_node(self, f, g):
return theano.Apply(self, inputs=[f,g],
outputs = [ theano.tensor.scalar() ])
def grad(self, inputs, output_grads):
return [ inputs[0].zeros_like(), NullType()() ]
x = theano.tensor.vector()
f, g = Op1()(x)
cost = Op2()(f,g)
# cost is differentiable wrt x
# but we can't tell that without using Op1's connection pattern
# looking at the theano graph alone, g is an ancestor of cost
# and has x as an ancestor, so we must compute its gradient
g = gradient.grad(cost, x)
# If we made it to here without an exception, then the
# connection_pattern functionality worked correctly
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论