提交 f010aa39 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added tests for undefined gradient functionality

上级 2fc402b3
...@@ -3,12 +3,13 @@ ...@@ -3,12 +3,13 @@
# UNIT TEST # UNIT TEST
# #
import unittest import unittest
import numpy
import theano import theano
from theano import gof from theano import gof
from theano.gradient import * from theano.gradient import grad_sources_inputs
from theano import gradient from theano import gradient
from theano.tensor.nnet.Conv3D import conv3D
from theano import config
def _grad_sources_inputs(*args): def _grad_sources_inputs(*args):
...@@ -265,6 +266,20 @@ def test_unimplemented_grad_func(): ...@@ -265,6 +266,20 @@ def test_unimplemented_grad_func():
except NotImplementedError: except NotImplementedError:
pass pass
def test_undefined_grad_func():
#tests that function compilation catches undefined grads in the graph
a = theano.tensor.vector()
b = theano.gradient.grad_undefined(theano.tensor.add, 0, a)
try:
f = theano.function([a],b)
assert 0
#Note: it's important that the GradUndefinedOp is cauhgt at
#COMPILATION time, not execution time.
#If the uncomputable variable is, for example, multiplied by0,
#it could be optimized out of the final graph
except theano.gradient.GradUndefinedError:
pass
def test_unimplemented_grad_grad(): def test_unimplemented_grad_grad():
#tests that unimplemented grads are caught in the grad method #tests that unimplemented grads are caught in the grad method
...@@ -284,6 +299,24 @@ def test_unimplemented_grad_grad(): ...@@ -284,6 +299,24 @@ def test_unimplemented_grad_grad():
except NotImplementedError: except NotImplementedError:
pass pass
def test_undefined_grad_grad():
#tests that undefined grads are caught in the grad method
V = theano.tensor.TensorType(dtype=config.floatX,
broadcastable = (False,False,False,False,False))()
W = theano.tensor.TensorType(dtype=config.floatX,
broadcastable = (False, False, False, False, False))()
b = theano.tensor.vector()
d = theano.tensor.ivector()
Z = conv3D(V,W,b,d)
try:
g = theano.gradient.grad(Z.sum(),d)
assert False
except theano.gradient.GradUndefinedError:
pass
def test_grad_name(): def test_grad_name():
A = theano.tensor.matrix('A') A = theano.tensor.matrix('A')
x = theano.tensor.vector('x') x = theano.tensor.vector('x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论