提交 cb94334d 作者: Ian Goodfellow

made unimplemented and undefined grads handled by NaNType

上级 122d7246
...@@ -21,6 +21,7 @@ from theano.gof import Variable ...@@ -21,6 +21,7 @@ from theano.gof import Variable
from theano.gof.python25 import all from theano.gof.python25 import all
import theano.gof.utils import theano.gof.utils
tensor = None tensor = None
from theano.gof.nan_type import NaNType
_msg_retType = 'op.grad(...) returned a non-list' _msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients' _msg_badlen = 'op.grad(...) returned wrong number of gradients'
...@@ -193,32 +194,6 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -193,32 +194,6 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
gmap[r] = g_r gmap[r] = g_r
return gmap return gmap
class GradNotImplementedOp(gof.op.UncomputableOp):
""" An UncomputableOp representing a gradient that hasn't been implemented yet.
"""
def __init__(self, op, x_pos, comment = ""):
"""
op: A theano op whose grad is not implemented for some input
x_pos: An int, giving the index in the op's input list of
a variable for which the gradient is not implemented
(if op has unimplemented gradients for several inputs,
it must still return a separate UnimplementedGradOp for
each)
comment: An optional comment explaining why the gradient isn't
implemented.
"""
assert isinstance(op, gof.Op)
assert isinstance(x_pos, int)
assert x_pos >= 0
super(GradNotImplementedOp,self).__init__(NotImplementedError,
"%s does not implement its gradient with respect to input %d. %s" \
% (str(type(op)), x_pos, comment))
def grad_not_implemented(op, x_pos, x, comment = ""): def grad_not_implemented(op, x_pos, x, comment = ""):
""" """
Return an un-computable symbolic variable of type `x.type`. Return an un-computable symbolic variable of type `x.type`.
...@@ -233,38 +208,9 @@ def grad_not_implemented(op, x_pos, x, comment = ""): ...@@ -233,38 +208,9 @@ def grad_not_implemented(op, x_pos, x, comment = ""):
gradient is not implemented. gradient is not implemented.
""" """
return GradNotImplementedOp(op, x_pos, comment)(x) return NaNType("This variable is NaN because the grad method for " + \
"input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \
class GradUndefinedError(Exception): "not implemented.")()
""" An exception raised upon attempts to use an undefined gradient.
"""
class GradUndefinedOp(gof.op.UncomputableOp):
""" An UncomputableOp representing a gradient that is mathematically
undefined.
"""
def __init__(self, op, x_pos, comment = ""):
"""
op: A theano op whose grad is mathematically undefined for
some input
x_pos: An int, giving the index in the op's input list of
a variable for which the gradient is undefined
(if op has undefined gradients for several inputs,
it must still return a separate GradUndefinedOp for
each)
comment: An optional comment explaining why the gradient isn't
defined.
"""
assert isinstance(op, gof.Op)
assert isinstance(x_pos, int)
assert x_pos >= 0
super(GradUndefinedOp,self).__init__(GradUndefinedError,
"%s does not implement its gradient with respect to input %d. %s" \
% (str(type(op)), x_pos, comment))
def grad_undefined(op, x_pos, x, comment = ""): def grad_undefined(op, x_pos, x, comment = ""):
""" """
...@@ -280,7 +226,9 @@ def grad_undefined(op, x_pos, x, comment = ""): ...@@ -280,7 +226,9 @@ def grad_undefined(op, x_pos, x, comment = ""):
gradient is not defined. gradient is not defined.
""" """
return GradUndefinedOp(op, x_pos, comment)(x) return NaNType("This variable is NaN because the gradient for " + \
"input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \
"mathematically undefined.")()
...@@ -503,6 +451,11 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore ...@@ -503,6 +451,11 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
if tensor is None: if tensor is None:
from theano import tensor from theano import tensor
if isinstance(cost.type, NaNType):
raise ValueError("Can't differentiate a NaN cost. cost is NaN because "+\
cost.type.why_nan)
if consider_constant is None: if consider_constant is None:
consider_constant = [] consider_constant = []
else: else:
...@@ -593,6 +546,9 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore ...@@ -593,6 +546,9 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
term_dict[node] = node.op.grad(node.inputs, term_dict[node] = node.op.grad(node.inputs,
[access_grad_cache(var) for var in node.outputs]) [access_grad_cache(var) for var in node.outputs])
for i in xrange(len(term_dict[node])): for i in xrange(len(term_dict[node])):
if isinstance(term_dict[node][i].type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\
term_dict[node][i].type.why_nan)
if term_dict[node][i] is None: if term_dict[node][i] is None:
term_dict[node][i] = tensor.zeros_like(node.inputs[i]) term_dict[node][i] = tensor.zeros_like(node.inputs[i])
return term_dict[node] return term_dict[node]
......
...@@ -2217,6 +2217,7 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -2217,6 +2217,7 @@ class T_argmin_argmax(unittest.TestCase):
def test_grad_argmin(self): def test_grad_argmin(self):
data = rand(2, 3) data = rand(2, 3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
n.name = 'n'
#test grad of argmin #test grad of argmin
utt.verify_grad(lambda v: argmin(v, axis=-1), [data]) utt.verify_grad(lambda v: argmin(v, axis=-1), [data])
...@@ -2228,7 +2229,11 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -2228,7 +2229,11 @@ class T_argmin_argmax(unittest.TestCase):
utt.verify_grad(lambda v: argmin(v.flatten()), [data]) utt.verify_grad(lambda v: argmin(v.flatten()), [data])
try: try:
grad(argmin(n, axis=-1), n) cost = argmin(n, axis=-1)
cost.name = None
g = grad(cost, n)
from theano.printing import min_informative_str
print min_informative_str(g)
raise Exception('Expected an error') raise Exception('Expected an error')
except TypeError: except TypeError:
pass pass
......
...@@ -6,264 +6,269 @@ import unittest ...@@ -6,264 +6,269 @@ import unittest
import theano import theano
from theano import gof from theano import gof
from theano.gradient import grad_sources_inputs #from theano.gradient import grad_sources_inputs
from theano import gradient from theano import gradient
from theano.tensor.nnet.Conv3D import conv3D from theano.tensor.nnet.Conv3D import conv3D
from theano import config from theano import config
def _grad_sources_inputs(*args): #def _grad_sources_inputs(*args):
# warn_type was introduced after this code, it complains throughout for nothing. # warn_type was introduced after this code, it complains throughout for nothing.
return grad_sources_inputs(warn_type=False, *args) # return grad_sources_inputs(warn_type=False, *args)
class test_grad_sources_inputs(unittest.TestCase): if 0:
def test_retNone1(self): #most of these tests are no longer relevant now that grad_sources_inputs is gone
"""Test that it is not ok to return None from op.grad()""" #also, some of our policies about what is allowed or not have changed
class retNone(gof.op.Op): #nonetheless, it may be a good idea to resurrect some of these tests and write
def make_node(self): #them in terms of tensor.grad directly
inputs = [gof.generic()] class test_grad_sources_inputs(unittest.TestCase):
outputs = [gof.generic()] def test_retNone1(self):
return gof.Apply(self, inputs, outputs) """Test that it is not ok to return None from op.grad()"""
def grad(self, inp, grads): class retNone(gof.op.Op):
x, = inp def make_node(self):
gz, = grads inputs = [gof.generic()]
pass outputs = [gof.generic()]
a = retNone().make_node() return gof.Apply(self, inputs, outputs)
try: def grad(self, inp, grads):
_grad_sources_inputs([(a.out, 1)], None) x, = inp
except ValueError, e: gz, = grads
self.assertTrue(e[0] is gradient._msg_retType) pass
return a = retNone().make_node()
self.fail() try:
def test_retNone1_b(self): _grad_sources_inputs([(a.out, 1)], None)
"""Test that it is ok to return [None] from op.grad()""" except ValueError, e:
class retNone(gof.op.Op): self.assertTrue(e[0] is gradient._msg_retType)
def make_node(self, *inputs): return
outputs = [gof.generic()] self.fail()
return gof.Apply(self, inputs, outputs) def test_retNone1_b(self):
def grad(self, inp, grads): """Test that it is ok to return [None] from op.grad()"""
return [None] class retNone(gof.op.Op):
i = gof.generic() def make_node(self, *inputs):
a = retNone().make_node(i) outputs = [gof.generic()]
g = _grad_sources_inputs([(a.out, 1)], None) return gof.Apply(self, inputs, outputs)
self.assertTrue(not i in g) def grad(self, inp, grads):
return [None]
i = gof.generic()
a = retNone().make_node(i)
g = _grad_sources_inputs([(a.out, 1)], None)
self.assertTrue(not i in g)
def test_wrong_rval_len1(self): def test_wrong_rval_len1(self):
"""Test that it is not ok to return the wrong number of gradients""" """Test that it is not ok to return the wrong number of gradients"""
class retNone(gof.op.Op): class retNone(gof.op.Op):
def make_node(self, *inputs): def make_node(self, *inputs):
outputs = [gof.generic()] outputs = [gof.generic()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inputs, grads): def grad(self, inputs, grads):
return [None] return [None]
i = gof.generic() i = gof.generic()
j = gof.generic() j = gof.generic()
a1 = retNone().make_node(i) a1 = retNone().make_node(i)
g = _grad_sources_inputs([(a1.out, 1)], None) g = _grad_sources_inputs([(a1.out, 1)], None)
a2 = retNone().make_node(i,j) a2 = retNone().make_node(i,j)
try: try:
g = _grad_sources_inputs([(a2.out, 1)], None) g = _grad_sources_inputs([(a2.out, 1)], None)
except ValueError, e: except ValueError, e:
self.assertTrue(e[0] is gradient._msg_badlen) self.assertTrue(e[0] is gradient._msg_badlen)
return return
self.fail() self.fail()
def test_stop_on_all_none(self): def test_stop_on_all_none(self):
"""Test that op.grad() is not called when output grads are all None""" """Test that op.grad() is not called when output grads are all None"""
class retNone(gof.op.Op): class retNone(gof.op.Op):
def __init__(self, tst): def __init__(self, tst):
self.tst = tst self.tst = tst
def make_node(self, *inputs): def make_node(self, *inputs):
outputs = [gof.generic()] outputs = [gof.generic()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inputs, grads): def grad(self, inputs, grads):
self.tst.fail() self.tst.fail()
i = gof.generic() i = gof.generic()
a1 = retNone(self).make_node(i) a1 = retNone(self).make_node(i)
g = _grad_sources_inputs([(a1.out, None)], None) g = _grad_sources_inputs([(a1.out, None)], None)
def test_1in_1out(self): def test_1in_1out(self):
"""Test grad is called correctly for a 1-to-1 op""" """Test grad is called correctly for a 1-to-1 op"""
gval = gof.generic() gval = gof.generic()
class O(gof.op.Op): class O(gof.op.Op):
def make_node(self): def make_node(self):
inputs = [gof.generic()] inputs = [gof.generic()]
outputs = [gof.generic()] outputs = [gof.generic()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads): def grad(self, inp, grads):
return gval, return gval,
a1 = O().make_node() a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval) self.assertTrue(g[a1.inputs[0]] is gval)
def test_1in_Nout(self): def test_1in_Nout(self):
"""Test grad is called correctly for a 1-to-many op""" """Test grad is called correctly for a 1-to-many op"""
gval = gof.generic() gval = gof.generic()
class O(gof.op.Op): class O(gof.op.Op):
def make_node(self): def make_node(self):
inputs = [gof.generic()] inputs = [gof.generic()]
outputs = [gof.generic(),gof.generic()] outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
gz1, gz2 = grads gz1, gz2 = grads
return gval, return gval,
a1 = O().make_node() a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval) self.assertTrue(g[a1.inputs[0]] is gval)
def test_Nin_1out(self): def test_Nin_1out(self):
"""Test grad is called correctly for a many-to-1 op""" """Test grad is called correctly for a many-to-1 op"""
gval0 = gof.generic() gval0 = gof.generic()
gval1 = gof.generic() gval1 = gof.generic()
class O(gof.op.Op): class O(gof.op.Op):
def make_node(self): def make_node(self):
inputs = [gof.generic(),gof.generic()] inputs = [gof.generic(),gof.generic()]
outputs = [gof.generic()] outputs = [gof.generic()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads): def grad(self, inp, grads):
x0, x1 = inp x0, x1 = inp
gz, = grads gz, = grads
return (gval0, gval1) return (gval0, gval1)
a1 = O().make_node() a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval0) self.assertTrue(g[a1.inputs[0]] is gval0)
self.assertTrue(g[a1.inputs[1]] is gval1) self.assertTrue(g[a1.inputs[1]] is gval1)
def test_Nin_Nout(self): def test_Nin_Nout(self):
"""Test grad is called correctly for a many-to-many op""" """Test grad is called correctly for a many-to-many op"""
gval0 = gof.generic() gval0 = gof.generic()
gval1 = gof.generic() gval1 = gof.generic()
class O(gof.op.Op): class O(gof.op.Op):
def make_node(self): def make_node(self):
inputs = [gof.generic(),gof.generic()] inputs = [gof.generic(),gof.generic()]
outputs = [gof.generic(),gof.generic()] outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inp, grads): def grad(self, inp, grads):
return gval0, gval1 return gval0, gval1
a1 = O().make_node() a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], 1)], None) g = _grad_sources_inputs([(a1.outputs[0], 1)], None)
self.assertTrue(g[a1.inputs[0]] is gval0) self.assertTrue(g[a1.inputs[0]] is gval0)
self.assertTrue(g[a1.inputs[1]] is gval1) self.assertTrue(g[a1.inputs[1]] is gval1)
def test_some_None_ograds(self): def test_some_None_ograds(self):
"""Test grad is called when some output gradients are None""" """Test grad is called when some output gradients are None"""
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self, tst): def __init__(self, tst):
self.tst = tst self.tst = tst
def make_node(self, *inputs): def make_node(self, *inputs):
outputs = [gof.generic(),gof.generic()] outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inputs, g_out): def grad(self, inputs, g_out):
return [1] return [1]
i = gof.generic() i = gof.generic()
a1 = O(self).make_node(i) a1 = O(self).make_node(i)
g = grad_sources_inputs([(a1.outputs[0], 1)], None, warn_type=False) g = grad_sources_inputs([(a1.outputs[0], 1)], None, warn_type=False)
self.assertTrue(g[i] is 1) self.assertTrue(g[i] is 1)
def test_some_None_igrads(self): def test_some_None_igrads(self):
"""Test that traversal works properly when an op return some None""" """Test that traversal works properly when an op return some None"""
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self, tst, grad_ok): def __init__(self, tst, grad_ok):
self.tst = tst self.tst = tst
self.grad_ok = grad_ok self.grad_ok = grad_ok
def make_node(self, *inputs): def make_node(self, *inputs):
outputs = [gof.generic(),gof.generic()] outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inputs, g_out): def grad(self, inputs, g_out):
if not self.grad_ok: if not self.grad_ok:
self.tst.fail() self.tst.fail()
else: else:
return [1, None] return [1, None]
i = gof.generic() i = gof.generic()
j = gof.generic() j = gof.generic()
k = gof.generic() k = gof.generic()
a1 = O(self, True).make_node(i,j) a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(a1.outputs[1], k) a2 = O(self, True).make_node(a1.outputs[1], k)
g = grad_sources_inputs([(a2.outputs[0], 1)], None, warn_type=False) g = grad_sources_inputs([(a2.outputs[0], 1)], None, warn_type=False)
self.assertTrue(g[i] is 1 and j not in g and k not in g) self.assertTrue(g[i] is 1 and j not in g and k not in g)
a1 = O(self, True).make_node(i,j) a1 = O(self, True).make_node(i,j)
a2 = O(self, True).make_node(k, a1.outputs[1]) a2 = O(self, True).make_node(k, a1.outputs[1])
g = _grad_sources_inputs([(a2.outputs[0], 1)], None) g = _grad_sources_inputs([(a2.outputs[0], 1)], None)
self.assertTrue(g[k] is 1 and i not in g and j not in g) self.assertTrue(g[k] is 1 and i not in g and j not in g)
def test_inputs(self): def test_inputs(self):
"""Test that passing inputs shortens the traversal""" """Test that passing inputs shortens the traversal"""
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self, tst, grad_ok): def __init__(self, tst, grad_ok):
self.tst = tst self.tst = tst
self.grad_ok = grad_ok self.grad_ok = grad_ok
def make_node(self, *inputs): def make_node(self, *inputs):
outputs = [gof.generic(),gof.generic()] outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inputs, grads): def grad(self, inputs, grads):
g0, g1 = grads g0, g1 = grads
if not self.grad_ok: if not self.grad_ok:
self.tst.fail() self.tst.fail()
else:
if g1:
return [g0, g0+g1]
else: else:
return [g0, g0] if g1:
i = gof.generic() return [g0, g0+g1]
j = gof.generic() else:
k = gof.generic() return [g0, g0]
a1 = O(self, True).make_node(i,j) i = gof.generic()
a2 = O(self, True).make_node(k,a1.outputs[1]) j = gof.generic()
g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4), k = gof.generic()
(a1.outputs[0], 3), (a1.outputs[0], 3)], a1.outputs) a1 = O(self, True).make_node(i,j)
self.assertTrue(g[a2.inputs[0]] == 1) a2 = O(self, True).make_node(k,a1.outputs[1])
self.assertTrue(g[a2.inputs[1]] == 5) g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
self.assertTrue(g[a1.outputs[0]] == 6) (a1.outputs[0], 3), (a1.outputs[0], 3)], a1.outputs)
self.assertTrue(g[a1.outputs[1]] == 5) self.assertTrue(g[a2.inputs[0]] == 1)
self.assertTrue(a1.inputs[0] not in g) self.assertTrue(g[a2.inputs[1]] == 5)
self.assertTrue(a1.inputs[1] not in g) self.assertTrue(g[a1.outputs[0]] == 6)
self.assertTrue(g[a1.outputs[1]] == 5)
self.assertTrue(a1.inputs[0] not in g)
self.assertTrue(a1.inputs[1] not in g)
def test_multiple_sources(self): def test_multiple_sources(self):
"""Test that passing multiple sources works""" """Test that passing multiple sources works"""
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self, tst, grad_ok): def __init__(self, tst, grad_ok):
self.tst = tst self.tst = tst
self.grad_ok = grad_ok self.grad_ok = grad_ok
def make_node(self, *inputs): def make_node(self, *inputs):
outputs = [gof.generic(),gof.generic()] outputs = [gof.generic(),gof.generic()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def grad(self, inputs, grads): def grad(self, inputs, grads):
g0, g1 = grads g0, g1 = grads
if not self.grad_ok: if not self.grad_ok:
self.tst.fail() self.tst.fail()
else:
if g1:
return [g0, g0+g1]
else: else:
return [g0, g0] if g1:
i = gof.generic() return [g0, g0+g1]
j = gof.generic() else:
k = gof.generic() return [g0, g0]
a1 = O(self,True).make_node(i,j) i = gof.generic()
a2 = O(self,True).make_node(k,a1.outputs[1]) j = gof.generic()
g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4), k = gof.generic()
(a1.outputs[0], 3), (a1.outputs[0], 3)], None) a1 = O(self,True).make_node(i,j)
self.assertTrue(g[a2.inputs[0]] == 1) a2 = O(self,True).make_node(k,a1.outputs[1])
self.assertTrue(g[a2.inputs[1]] == 5) g = _grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1],4),
self.assertTrue(g[a1.outputs[0]] == 6) (a1.outputs[0], 3), (a1.outputs[0], 3)], None)
self.assertTrue(g[a1.outputs[1]] == 5) self.assertTrue(g[a2.inputs[0]] == 1)
self.assertTrue(g[a1.inputs[0]] == 6) self.assertTrue(g[a2.inputs[1]] == 5)
self.assertTrue(g[a1.inputs[1]] == 11) self.assertTrue(g[a1.outputs[0]] == 6)
self.assertTrue(g[a1.outputs[1]] == 5)
self.assertTrue(g[a1.inputs[0]] == 6)
self.assertTrue(g[a1.inputs[1]] == 11)
def test_unimplemented_grad_func(): def test_unimplemented_grad_func():
#tests that function compilation catches unimplemented grads in the graph #tests that function compilation catches unimplemented grads in the graph
a = theano.tensor.vector() a = theano.tensor.vector()
b = theano.gradient.grad_not_implemented(theano.tensor.add, 0, a) b = theano.gradient.grad_not_implemented(theano.tensor.add, 0, a)
try: try:
f = theano.function([a], b) f = theano.function([a], b, on_unused_input = 'ignore')
assert 0 assert 0
#Note: it's important that the NotImplementedGradOp is caught #Note: it's important that the NotImplementedGradOp is caught
#at COMPILATION time, not execution time. #at COMPILATION time, not execution time.
#If the uncomputable variable is, for example, multiplied by 0, #If the uncomputable variable is, for example, multiplied by 0,
#it could be optimized out of the final graph. #it could be optimized out of the final graph.
except NotImplementedError: except TypeError:
pass pass
def test_undefined_grad_func(): def test_undefined_grad_func():
...@@ -271,13 +276,13 @@ def test_undefined_grad_func(): ...@@ -271,13 +276,13 @@ def test_undefined_grad_func():
a = theano.tensor.vector() a = theano.tensor.vector()
b = theano.gradient.grad_undefined(theano.tensor.add, 0, a) b = theano.gradient.grad_undefined(theano.tensor.add, 0, a)
try: try:
f = theano.function([a],b) f = theano.function([a],b, on_unused_input = 'ignore')
assert 0 assert 0
#Note: it's important that the GradUndefinedOp is cauhgt at #Note: it's important that the GradUndefinedOp is caught at
#COMPILATION time, not execution time. #COMPILATION time, not execution time.
#If the uncomputable variable is, for example, multiplied by0, #If the uncomputable variable is, for example, multiplied by0,
#it could be optimized out of the final graph #it could be optimized out of the final graph
except theano.gradient.GradUndefinedError: except TypeError:
pass pass
def test_unimplemented_grad_grad(): def test_unimplemented_grad_grad():
...@@ -296,7 +301,7 @@ def test_unimplemented_grad_grad(): ...@@ -296,7 +301,7 @@ def test_unimplemented_grad_grad():
try: try:
g = theano.gradient.grad(b,a) g = theano.gradient.grad(b,a)
assert False assert False
except NotImplementedError: except TypeError:
pass pass
def test_undefined_grad_grad(): def test_undefined_grad_grad():
...@@ -314,7 +319,7 @@ def test_undefined_grad_grad(): ...@@ -314,7 +319,7 @@ def test_undefined_grad_grad():
try: try:
g = theano.gradient.grad(Z.sum(),d) g = theano.gradient.grad(Z.sum(),d)
assert False assert False
except theano.gradient.GradUndefinedError: except TypeError:
pass pass
def test_grad_name(): def test_grad_name():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论