Commit 0b186aec, authored by Pascal Lamblin

Explicitly calls .sum() to get a scalar cost before computing gradients.

Closes #539.
Parent: c4b8a8d6
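For context: Theano's tensor.grad accepts only a scalar cost, so any graph whose cost is a vector or matrix must be reduced (with .sum() or .mean()) before differentiation. A minimal sketch of the pattern this commit adopts, with illustrative variable names not taken from the patch:

    import theano
    import theano.tensor as T

    x = T.vector('x')
    per_element_cost = x ** 2          # one cost value per element: not a scalar

    # T.grad(per_element_cost, x) would be rejected because the cost is
    # not a scalar; summing first gives a scalar whose gradient w.r.t. x
    # is the sum of the per-element gradients, here 2*x.
    g = T.grad(per_element_cost.sum(), x)

    f = theano.function([x], g)
    print(f([1.0, 2.0, 3.0]))          # -> [ 2.  4.  6.]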
@@ -278,10 +278,11 @@ def test_mlp():
     classifier = MLP( rng = rng, input=x, n_in=28*28, n_hidden = 500, n_out=10)

     # the cost we minimize during training is the negative log likelihood of
-    # the model
-    cost = classifier.negative_log_likelihood(y)
+    # the model.
+    # We take the mean of the cost over each minibatch.
+    cost = classifier.negative_log_likelihood(y).mean()

-    # compute the gradient of cost with respect to theta (sotred in params)
+    # compute the gradient of cost with respect to theta (stored in params)
     # the resulting gradients will be stored in a list gparams
     gparams = []
     for param in classifier.params:
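Using .mean() rather than .sum() in the tutorial code keeps the magnitude of the cost, and therefore the effective learning rate, independent of the minibatch size. A hedged sketch of the surrounding setup, assuming negative_log_likelihood returns one value per example at this revision (the linear softmax stand-in below is illustrative, not the MLP from the patch):

    import numpy
    import theano
    import theano.tensor as T

    x = T.matrix('x')    # a minibatch of examples, one per row
    y = T.ivector('y')   # the matching integer labels

    # hypothetical linear softmax classifier standing in for the MLP
    W = theano.shared(numpy.zeros((784, 10), dtype=theano.config.floatX))
    p_y_given_x = T.nnet.softmax(T.dot(x, W))

    # one negative log-likelihood value per example in the minibatch
    nll = -T.log(p_y_given_x)[T.arange(y.shape[0]), y]

    cost = nll.mean()    # scalar, and batch-size independent
    g_W = T.grad(cost, W)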
@@ -926,7 +926,7 @@ class T_max_and_argmax(unittest.TestCase):
         utt.verify_grad(lambda v: max_and_argmax(v,axis=[0])[0], [data])
         utt.verify_grad(lambda v: max_and_argmax(v,axis=[0])[1], [data])
-        check_grad_max(data,eval_outputs(grad(max_and_argmax(n,axis=0)[0],n)),axis=0)
+        check_grad_max(data,eval_outputs(grad(max_and_argmax(n,axis=0)[0].sum(),n)),axis=0)
         utt.verify_grad(lambda v: max_and_argmax(v,axis=[1])[0], [data])
         utt.verify_grad(lambda v: max_and_argmax(v,axis=[1])[1], [data])
@@ -1241,7 +1241,7 @@ class T_min_max(unittest.TestCase):
         utt.verify_grad(lambda v: max(v), [data])
         utt.verify_grad(lambda v: max(v,axis=[0]), [data])
-        check_grad_max(data,eval_outputs(grad(max(n,axis=0),n)),axis=0)
+        check_grad_max(data,eval_outputs(grad(max(n,axis=0).sum(),n)),axis=0)
         utt.verify_grad(lambda v: max(v,axis=[1]), [data])
         #check_grad_max(data,eval_outputs(grad(max(n,axis=1),n)),axis=1)
@@ -1273,7 +1273,7 @@ class T_min_max(unittest.TestCase):
         utt.verify_grad(lambda v: min(v), [data])
         utt.verify_grad(lambda v: min(v,axis=[0]), [data])
-        check_grad_min(data,eval_outputs(grad(min(n,axis=0),n)),axis=0)
+        check_grad_min(data,eval_outputs(grad(min(n,axis=0).sum(),n)),axis=0)
         utt.verify_grad(lambda v: min(v,axis=[1]), [data])
         #check_grad_min(data,eval_outputs(grad(min(n,axis=1),n)),axis=1)
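Note that the .sum() added in these tests does not change the quantity being verified: the gradient of a sum is a tensor of ones, so grad(max(n,axis=0).sum(), n) is 1 at each column maximum and 0 elsewhere, which is the pattern check_grad_max looks for. A small NumPy sketch of that indicator (assuming unique maxima; not the tests' own code):

    import numpy

    n = numpy.array([[1., 5., 2.],
                     [4., 3., 6.]])

    # d(max(n, axis=0).sum())/dn: 1 where a column maximum sits, 0 elsewhere
    g = (n == n.max(axis=0)).astype(n.dtype)
    print(g)
    # [[ 0.  1.  0.]
    #  [ 1.  0.  1.]]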
@@ -1843,7 +1843,7 @@ class T_local_erfc(unittest.TestCase):
         mode_fusion = copy.copy(self.mode_fusion)
         mode_fusion.check_isfinite = False
-        f = theano.function([x],T.grad(T.log(T.erfc(x)),x), mode=mode)
+        f = theano.function([x],T.grad(T.log(T.erfc(x)).sum(),x), mode=mode)
         #theano.printing.debugprint(f)
         assert len(f.maker.env.nodes)==23, len(f.maker.env.nodes)
         assert all(numpy.isfinite(f(val)))
@@ -1878,13 +1878,13 @@ class T_local_erfc(unittest.TestCase):
         assert all(numpy.isfinite(f(val)))

         #test that it work correctly if x is x*2 in the graph.
-        f = theano.function([x],T.grad(T.log(T.erfc(2*x)),x), mode=mode)
+        f = theano.function([x],T.grad(T.log(T.erfc(2*x)).sum(),x), mode=mode)
         #theano.printing.debugprint(f)
         assert len(f.maker.env.nodes)==23, len(f.maker.env.nodes)
         assert numpy.isfinite(f(val)).all()
         assert f.maker.env.outputs[0].dtype==theano.config.floatX

-        f = theano.function([x],T.grad(T.log(T.erfc(x)),x), mode=mode_fusion)
+        f = theano.function([x],T.grad(T.log(T.erfc(x)).sum(),x), mode=mode_fusion)
         assert len(f.maker.env.nodes)==1, len(f.maker.env.nodes)
         assert f.maker.env.outputs[0].dtype==theano.config.floatX