Commit 5764be52 authored by Sander Dieleman

added consider_constant op, which truncates the gradient of an expression to 0.…

added consider_constant op, which truncates the gradient of an expression to 0. Also added an optimization to remove it from the graph, so other optimizations can be applied properly. test whether the gradient is computed correctly, and test whether the consider_constant op is removed from the graph.
Parent: 85209fbb
......@@ -62,4 +62,5 @@ from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \
from theano.tensor.sort import sort, argsort
from theano.tensor.extra_ops import (DiffOp, bincount, squeeze,
repeat, bartlett, fill_diagonal, cumsum, cumprod)
repeat, bartlett, fill_diagonal, cumsum, cumprod,
consider_constant)
......@@ -6,6 +6,8 @@ from theano.tensor import basic
from theano import gof, scalar
tensor = basic
from theano.gradient import DisconnectedType
from theano.compile import ViewOp
from theano.tensor.opt import register_canonicalize
class CumsumOp(theano.Op):
......@@ -723,3 +725,35 @@ def fill_diagonal(a, val):
.. versionadded:: 0.6
"""
return fill_diagonal_(a, val)
class ConsiderConstant(ViewOp):
    """Identity op whose gradient is truncated to zero.

    The forward pass is inherited from ViewOp (returns the input
    unchanged); only the gradient is overridden so that no gradient
    flows back through this op.
    """

    def grad(self, args, g_outs):
        # One zero tensor per incoming output gradient: backprop stops here.
        return list(map(tensor.zeros_like, g_outs))
# Singleton instance; the consider_constant() wrapper below calls it.
consider_constant_ = ConsiderConstant()

# Although the op just returns its input, it should be removed from
# the graph to make sure all possible optimizations can be applied.
# NOTE(review): registering OpRemove as a canonicalization strips the
# op only after gradients have been constructed, so the zero-gradient
# behavior is preserved while the compiled graph stays clean.
register_canonicalize(gof.OpRemove(consider_constant_),
                      name='remove_consider_constant')
# A plain function wrapper is used (rather than exposing the op
# instance directly) so the docstring renders properly in the docs.
def consider_constant(x):
    """Treat an expression as constant when computing gradients.

    The expression itself is unaffected, but whenever the gradient of
    this expression -- or of any expression that contains it as a
    subexpression -- is computed, backpropagation does not proceed
    through it: its gradient is truncated to 0.

    :param x: A Theano expression whose gradient should be truncated.

    :return: The expression, returned unmodified, but with its
        gradient now truncated to 0.

    .. versionadded:: 0.6.1
    """
    return consider_constant_(x)
import numpy as np
import numpy
import unittest
import theano
from theano.tests import unittest_tools as utt
......@@ -7,7 +8,7 @@ from theano.tests import unittest_tools as utt
from theano.tensor.extra_ops import (CumsumOp, cumsum, CumprodOp, cumprod,
BinCountOp, bincount, DiffOp, diff,
squeeze, RepeatOp, repeat, Bartlett, bartlett,
FillDiagonal, fill_diagonal)
FillDiagonal, fill_diagonal, consider_constant)
from theano import tensor as T
from theano import config, tensor, function
......@@ -463,3 +464,41 @@ class TestFillDiagonal(utt.InferShapeTester):
numpy.random.rand()],
self.op_class,
warn=False)
class TestConsiderConstant(unittest.TestCase):
    """Tests for the consider_constant op.

    Checks (1) that the op is optimized out of the compiled graph and
    (2) that gradients through it are truncated to zero.
    """

    def setUp(self):
        utt.seed_rng()
        self.rng = np.random.RandomState(seed=utt.fetch_seed())

    def test_op_removed(self):
        """The ConsiderConstant op must not survive graph optimization."""
        x = T.matrix('x')
        y = x * consider_constant(x)
        f = theano.function([x], y)
        # need to refer to T.extra_ops.consider_constant_ here,
        # T.consider_constant is a wrapper function!
        assert T.extra_ops.consider_constant_ not in \
            [node.op for node in f.maker.fgraph.toposort()]

    def test_grad(self):
        """Gradients of expressions containing consider_constant match
        the hand-derived gradients where the wrapped term is constant."""
        # Fixed: use the np alias consistently (block previously mixed
        # `numpy.asarray` and `np.allclose`) and drop an unused local
        # expression that was built but never referenced.
        a = np.asarray(self.rng.randn(50, 50), dtype=config.floatX)
        x = T.matrix('x')
        # (expression, expected gradient of expression.sum() w.r.t. x)
        expressions_gradients = [
            (x * consider_constant(x), x),
            (x * consider_constant(T.exp(x)), T.exp(x)),
            (consider_constant(x), T.constant(0.)),
            (x ** 2 * consider_constant(x), 2 * x ** 2),
        ]
        for expr, expr_grad in expressions_gradients:
            g = T.grad(expr.sum(), x)
            # gradient according to theano
            f = theano.function([x], g, on_unused_input='ignore')
            # desired gradient
            f2 = theano.function([x], expr_grad, on_unused_input='ignore')
            assert np.allclose(f(a), f2(a))
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment