提交 6acf4e67 authored 作者: abergeron's avatar abergeron

Merge pull request #2403 from nouiz/gradient_clipping

add GradClip op
...@@ -1888,3 +1888,46 @@ def consider_constant(x):
    .. versionadded:: 0.6.1
    """
    return consider_constant_(x)
class GradClip(theano.compile.ViewOp):
    """Identity (view) op in the forward pass whose gradient is clipped.

    See the user-facing function :func:`grad_clip` for usage.
    """
    # The clip bounds are deliberately NOT part of __props__ (so they do
    # not enter __eq__/__hash__): they do not influence the perform of
    # this op, only its gradient.
    __props__ = ()

    def __init__(self, clip_lower_bound, clip_upper_bound):
        """
        :param clip_lower_bound: lower bound applied to the gradient.
        :param clip_upper_bound: upper bound applied to the gradient.
            Must be >= ``clip_lower_bound``.

        :raises ValueError: if ``clip_upper_bound < clip_lower_bound``.
        """
        # Raise explicitly rather than assert: asserts are stripped
        # under python -O, which would silently accept inverted bounds.
        if clip_upper_bound < clip_lower_bound:
            raise ValueError(
                "clip_upper_bound (%s) must be >= clip_lower_bound (%s)" %
                (clip_upper_bound, clip_lower_bound))
        self.clip_lower_bound = clip_lower_bound
        self.clip_upper_bound = clip_upper_bound

    def grad(self, args, g_outs):
        # Clip each output gradient elementwise into [lower, upper];
        # the forward pass itself is untouched (ViewOp).
        return [theano.tensor.clip(g_out, self.clip_lower_bound,
                                   self.clip_upper_bound)
                for g_out in g_outs]
def grad_clip(x, lower_bound, upper_bound):
    """
    Return a view of ``x`` whose gradient is elementwise clipped to
    ``[lower_bound, upper_bound]`` during backpropagation.

    The forward pass is the identity (a view), so only the gradient is
    affected. This is an elemwise operation.

    :param x: the variable we want its gradient inputs clipped.
    :param lower_bound: the lower bound of the gradient value.
    :param upper_bound: the upper bound of the gradient value.
        Must be >= ``lower_bound``.

    :examples:

        x = theano.tensor.scalar()

        z = theano.tensor.grad(grad_clip(x, -1, 1)**2, x)
        z2 = theano.tensor.grad(x**2, x)

        f = theano.function([x], outputs = [z, z2])

        print(f(2.0))  # output (1.0, 4.0)

    :note: We register an opt in tensor/opt.py that removes the GradClip.
       So it has 0 cost in the forward and only does work in the grad.
    """
    return GradClip(lower_bound, upper_bound)(x)
...@@ -5551,3 +5551,10 @@ else:
# the graph to make sure all possible optimizations can be applied.
register_canonicalize(gof.OpRemove(theano.gradient.consider_constant_),
                      'fast_compile', 'fast_run', name='remove_consider_constant')
@register_canonicalize
@gof.local_optimizer([theano.gradient.GradClip])
def local_grad_clip(node):
    """Remove GradClip nodes: the op is the identity in the forward
    pass, so replacing the node by its input is free and leaves only
    the gradient-clipping effect (already baked into the grad graph)."""
    if not isinstance(node.op, theano.gradient.GradClip):
        return None
    return node.inputs
...@@ -3,6 +3,9 @@
# UNIT TEST
#
import unittest

import numpy as np

import theano
from theano import gof
from theano.tests import unittest_tools as utt
...@@ -10,7 +13,6 @@ from theano.tests import unittest_tools as utt
from theano import gradient
from theano.tensor.nnet.Conv3D import conv3D
from theano import config
from theano.gof.null_type import NullType

one = theano.tensor.as_tensor_variable(1.)
...@@ -641,5 +643,21 @@ class TestConsiderConstant(unittest.TestCase):
        assert np.allclose(f(a), f2(a))
def test_grad_clip():
    """Check that grad_clip is optimized away in the forward graph
    while its gradient is clipped to the requested bounds."""
    x = theano.tensor.scalar()

    clipped_grad = theano.tensor.grad(gradient.grad_clip(x, -1, 1) ** 2, x)
    plain_grad = theano.tensor.grad(x ** 2, x)

    f = theano.function([x], outputs=[clipped_grad, plain_grad])

    # The canonicalization opt should have removed every GradClip node.
    if theano.config.mode != "FAST_COMPILE":
        assert all(not isinstance(node.op, gradient.GradClip)
                   for node in f.maker.fgraph.toposort())

    out = f(2.)
    # d/dx x**2 at x=2 is 4; clipped to [-1, 1] it becomes 1.
    assert np.allclose(out, (1, 4))
    assert not np.allclose(out[0], out[1])
if __name__ == '__main__':
    unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论