提交 238d2a84 authored 作者: carriepl's avatar carriepl

Merge pull request #2526 from julianser/Fix_2454

First attempt at fixing issue 2454: making consider_constant() functiona...
...@@ -1881,12 +1881,18 @@ def _is_zero(x): ...@@ -1881,12 +1881,18 @@ def _is_zero(x):
class ConsiderConstant(ViewOp): class ConsiderConstant(ViewOp):
def grad(self, args, g_outs): def grad(self, args, g_outs):
return [g_out.zeros_like(g_out) for g_out in g_outs] return [g_out.zeros_like(g_out) for g_out in g_outs]
consider_constant_ = ConsiderConstant() consider_constant_ = ConsiderConstant()
#I create a function only to have the doc show well. # I create a function only to have the doc show well.
def consider_constant(x): def consider_constant(x):
""" Consider an expression constant when computing gradients.
"""
DEPRECATED: use zero_grad() or disconnected_grad() instead.
Consider an expression constant when computing gradients.
The expression itself is unaffected, but when its gradient is The expression itself is unaffected, but when its gradient is
computed, or the gradient of another expression that this computed, or the gradient of another expression that this
...@@ -1901,9 +1907,71 @@ def consider_constant(x): ...@@ -1901,9 +1907,71 @@ def consider_constant(x):
.. versionadded:: 0.6.1 .. versionadded:: 0.6.1
""" """
warnings.warn((
"consider_constant() is deprecated, use zero_grad() or "
"disconnected_grad() instead."), stacklevel=3)
return consider_constant_(x) return consider_constant_(x)
class ZeroGrad(ViewOp):
    """``ViewOp`` subclass whose gradient is replaced by zeros.

    For every incoming output gradient, ``grad`` hands back a
    zero-filled expression built with ``zeros_like``.
    """

    def grad(self, args, g_outs):
        # Build one zero-filled gradient per incoming output gradient.
        zeroed = []
        for out_grad in g_outs:
            zeroed.append(out_grad.zeros_like(out_grad))
        return zeroed


# Module-level instance used by the ``zero_grad`` wrapper function.
zero_grad_ = ZeroGrad()
def zero_grad(x):
    """
    Treat an expression as a constant when gradients are computed.

    The value of the expression is left untouched.  When the gradient
    of the expression (or of any larger expression containing it) is
    taken, a gradient of zero is backpropagated through it, i.e. the
    gradient of the expression is truncated to 0.

    :param x: A Theano expression whose gradient should be truncated.

    :return: The expression is returned unmodified, but its gradient
        is now truncated to 0.
    """
    # Apply the shared ZeroGrad op instance to the expression.
    truncated = zero_grad_(x)
    return truncated
class DisconnectedGrad(ViewOp):
    """``ViewOp`` subclass whose gradient is reported as disconnected.

    For every incoming output gradient, ``grad`` hands back a fresh
    ``disconnected_type()`` value instead of a numeric gradient.
    """

    def grad(self, args, g_outs):
        # The incoming gradients are only counted, never used.
        disconnected = []
        for _ in g_outs:
            disconnected.append(disconnected_type())
        return disconnected


# Module-level instance used by the ``disconnected_grad`` wrapper function.
disconnected_grad_ = DisconnectedGrad()
def disconnected_grad(x):
    """
    Treat an expression as a constant when gradients are computed,
    without backpropagating through it at all.

    The value of the expression is left untouched.  When the gradient
    of the expression (or of any larger expression containing it) is
    taken, nothing is backpropagated through it.  This is effectively
    equivalent to truncating the gradient expression to 0, but is
    executed faster than zero_grad(), which still has to go through
    the underlying computational graph related to the expression.

    :param x: A Theano expression whose gradient should not be
        backpropagated through.

    :return: The expression is returned unmodified, but its gradient
        is now effectively truncated to 0.
    """
    # Apply the shared DisconnectedGrad op instance to the expression.
    detached = disconnected_grad_(x)
    return detached
class GradClip(ViewOp): class GradClip(ViewOp):
# See doc in user fct grad_clip # See doc in user fct grad_clip
__props__ = () __props__ = ()
......
...@@ -5587,11 +5587,17 @@ else: ...@@ -5587,11 +5587,17 @@ else:
# # Remove consider_constant # # # Remove consider_constant #
# ############################ # ############################
# Although the op just returns its input, it should be removed from # Although the ops ConsiderConstant, ZeroGrad and DisconnectedGrad
# the graph to make sure all possible optimizations can be applied. # just return their input, they should be removed from the graph to
# make sure all possible optimizations can be applied.
register_canonicalize(gof.OpRemove(theano.gradient.consider_constant_), register_canonicalize(gof.OpRemove(theano.gradient.consider_constant_),
'fast_compile', 'fast_run', name='remove_consider_constant') 'fast_compile', 'fast_run', name='remove_consider_constant')
# Register an OpRemove for the zero_grad_ and disconnected_grad_ op
# instances as canonicalization optimizations, each under its own name
# and with the 'fast_compile' and 'fast_run' arguments.
register_canonicalize(gof.OpRemove(theano.gradient.zero_grad_),
'fast_compile', 'fast_run', name='remove_zero_grad')
register_canonicalize(gof.OpRemove(theano.gradient.disconnected_grad_),
'fast_compile', 'fast_run', name='remove_disconnected_grad')
@register_canonicalize @register_canonicalize
@gof.local_optimizer([theano.gradient.GradClip]) @gof.local_optimizer([theano.gradient.GradClip])
......
...@@ -556,7 +556,7 @@ def test_disconnected_cost_grad(): ...@@ -556,7 +556,7 @@ def test_disconnected_cost_grad():
except theano.gradient.DisconnectedInputError: except theano.gradient.DisconnectedInputError:
return return
raise AssertionError("A disconnected gradient has been ignored.") raise AssertionError("A disconnected gradient has been ignored.")
def test_subgraph_grad(): def test_subgraph_grad():
# Tests that the grad method with no known_grads # Tests that the grad method with no known_grads
...@@ -618,12 +618,12 @@ class TestConsiderConstant(unittest.TestCase): ...@@ -618,12 +618,12 @@ class TestConsiderConstant(unittest.TestCase):
# theano.gradient.consider_constant is a wrapper function! # theano.gradient.consider_constant is a wrapper function!
assert gradient.consider_constant_ not in \ assert gradient.consider_constant_ not in \
[node.op for node in f.maker.fgraph.toposort()] [node.op for node in f.maker.fgraph.toposort()]
def test_grad(self): def test_grad(self):
T = theano.tensor T = theano.tensor
a = np.asarray(self.rng.randn(5, 5), a = np.asarray(self.rng.randn(5, 5),
dtype=config.floatX) dtype=config.floatX)
x = T.matrix('x') x = T.matrix('x')
expressions_gradients = [ expressions_gradients = [
...@@ -643,6 +643,111 @@ class TestConsiderConstant(unittest.TestCase): ...@@ -643,6 +643,111 @@ class TestConsiderConstant(unittest.TestCase):
assert np.allclose(f(a), f2(a)) assert np.allclose(f(a), f2(a))
class TestZeroGrad(unittest.TestCase):
    """Tests for ``gradient.zero_grad`` and its op ``gradient.zero_grad_``."""

    def setUp(self):
        # Seed the shared test RNG, then build a local RandomState
        # from the fetched seed.
        utt.seed_rng()
        seed = utt.fetch_seed()
        self.rng = np.random.RandomState(seed=seed)

    def test_op_removed(self):
        """The ZeroGrad op must not be present in the compiled graph."""
        x = theano.tensor.matrix('x')
        product = x * gradient.zero_grad(x)
        compiled = theano.function([x], product)
        # The op instance is gradient.zero_grad_;
        # gradient.zero_grad is only a wrapper function.
        compiled_ops = [node.op
                        for node in compiled.maker.fgraph.toposort()]
        assert gradient.zero_grad_ not in compiled_ops

    def test_grad(self):
        """Compare Theano's gradient with the expected expression."""
        tensor = theano.tensor
        sample = np.asarray(self.rng.randn(5, 5),
                            dtype=config.floatX)

        x = tensor.matrix('x')

        # Pairs of (expression, expected gradient w.r.t. x).
        cases = [
            (x * gradient.zero_grad(x), x),
            (x * gradient.zero_grad(tensor.exp(x)), tensor.exp(x)),
            (gradient.zero_grad(x), tensor.constant(0.)),
            (x**2 * gradient.zero_grad(x), 2 * x**2),
        ]

        for expression, expected in cases:
            symbolic_grad = gradient.grad(expression.sum(), x)
            # gradient according to theano
            actual_fn = theano.function([x], symbolic_grad,
                                        on_unused_input='ignore')
            # desired gradient
            expected_fn = theano.function([x], expected,
                                          on_unused_input='ignore')

            assert np.allclose(actual_fn(sample), expected_fn(sample))
class TestDisconnectedGrad(unittest.TestCase):
    """Tests for ``gradient.disconnected_grad`` and its op instance."""

    def setUp(self):
        # Seed the shared test RNG, then build a local RandomState
        # from the fetched seed.
        utt.seed_rng()
        seed = utt.fetch_seed()
        self.rng = np.random.RandomState(seed=seed)

    def test_op_removed(self):
        """The DisconnectedGrad op must not be in the compiled graph."""
        x = theano.tensor.matrix('x')
        product = x * gradient.disconnected_grad(x)
        compiled = theano.function([x], product)
        # The op instance is gradient.disconnected_grad_;
        # gradient.disconnected_grad is only a wrapper function.
        compiled_ops = [node.op
                        for node in compiled.maker.fgraph.toposort()]
        assert gradient.disconnected_grad_ not in compiled_ops

    def test_grad(self):
        """Compare Theano's gradient with the expected expression."""
        tensor = theano.tensor
        sample = np.asarray(self.rng.randn(5, 5),
                            dtype=config.floatX)

        x = tensor.matrix('x')

        # Pairs of (expression, expected gradient w.r.t. x).
        cases = [
            (x * gradient.disconnected_grad(x), x),
            (x * gradient.disconnected_grad(tensor.exp(x)), tensor.exp(x)),
            (x**2 * gradient.disconnected_grad(x), 2 * x**2),
        ]

        for expression, expected in cases:
            symbolic_grad = gradient.grad(expression.sum(), x)
            # gradient according to theano
            actual_fn = theano.function([x], symbolic_grad,
                                        on_unused_input='ignore')
            # desired gradient
            expected_fn = theano.function([x], expected,
                                          on_unused_input='ignore')

            assert np.allclose(actual_fn(sample), expected_fn(sample))

    def test_disconnected_paths(self):
        """Gradients through a fully disconnected path must raise."""
        # Test that taking gradient going through a disconnected
        # path raises an exception
        tensor = theano.tensor
        # NOTE: this sample is not used by the checks below; the draw is
        # kept so the method does exactly what it did before.
        a = np.asarray(self.rng.randn(5, 5),
                       dtype=config.floatX)

        x = tensor.matrix('x')

        # This MUST raise a DisconnectedInputError error.
        # This also raises an additional warning from gradients.py.
        only_disconnected = gradient.disconnected_grad(x).sum()
        self.assertRaises(gradient.DisconnectedInputError, gradient.grad,
                          only_disconnected, x)

        # This MUST NOT raise a DisconnectedInputError error.
        gradient.grad((x + gradient.disconnected_grad(x)).sum(), x)

        a = tensor.matrix('a')
        b = tensor.matrix('b')
        y = a + gradient.disconnected_grad(b)
        # This MUST raise a DisconnectedInputError error.
        # This also raises an additional warning from gradients.py.
        self.assertRaises(gradient.DisconnectedInputError,
                          gradient.grad, y.sum(), b)

        # This MUST NOT raise a DisconnectedInputError error.
        gradient.grad(y.sum(), a)
def test_grad_clip(): def test_grad_clip():
x = theano.tensor.scalar() x = theano.tensor.scalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论