提交 d9671761 authored 作者: abergeron's avatar abergeron

Merge pull request #2720 from yingzha/ccw

added documentation and fixed the gradient of the Assert Op
......@@ -24,6 +24,7 @@ from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant
from theano.compat.python2x import maxsize
from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType
from theano.configparser import config
from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
......@@ -1492,11 +1493,22 @@ class Assert(T.Op):
"""
Implements assertion in a computational graph.
Returns the first parameter if the condition is true, otherwise, triggers
AssertionError.
Example:
T = theano.tensor
x = T.vector('x')
assert_op = T.opt.Assert()
func = theano.function([x], assert_op(x, x.size<2))
Notes:
This Op can be removed from the graph because of optimizations, and can
hide some possible optimizations to the optimizer.
Also, the output of the Op must be returned by the function computing the
graph, otherwise it will not be used.
This Op is a debugging feature. It can be removed from the graph
because of optimizations, and can hide some possible optimizations to
the optimizer. Specifically, removing happens if it can be determined
that condition will always be true. Also, the output of the Op must be
used in the function computing the graph, but it doesn't have to be
returned.
"""
view_map = {0: [0]}
......@@ -1533,7 +1545,7 @@ class Assert(T.Op):
return hash(type(self)) ^ hash(self.msg)
def grad(self, input, output_gradients):
return output_gradients
return output_gradients + [DisconnectedType()()] * (len(input) - 1)
def c_code(self, node, name, inames, onames, sub):
value = inames[0]
......
import numpy
import theano
import theano.tensor as T
from theano.tensor.opt import Assert
def test_assert_op_gradient():
x = T.vector('x')
assert_op = Assert()
cost = T.sum(assert_op(x, x.size < 2))
grad = T.grad(cost, x)
func = theano.function([x], grad)
x_val = numpy.ones(shape=(1,))
assert func(x_val) == 1
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论