提交 a1703e4a authored 作者: Ying Zhang's avatar Ying Zhang

added documentation and fixed the gradient of the Assert Op

上级 e9328fdd
......@@ -24,6 +24,7 @@ from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant
from theano.compat.python2x import maxsize
from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType
from theano.configparser import config
from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
......@@ -1487,12 +1488,20 @@ def local_cast_cast(node):
class Assert(T.Op):
"""
Implements assertion in a computational graph.
Returns the first parameter if the condition is true, otherwise, trigger
AssertionError.
Example:
T = theano.tensor
x = T.vector('x')
assert_op = T.opt.Assert()
func = theano.function([x], assert_op(x, x.size<2))
Notes:
This Op can be removed from the graph because of optimizations, and can
hide some possible optimizations to the optimizer.
Also, the output of the Op must be returned by the function computing the
graph, otherwise it will not be used.
This Op is an debugging feature. It can be removed from the graph
because of optimizations, and can hide some possible optimizations to the
optimizer. Also, the output of the Op must be returned by the function
computing the graph, otherwise it will not be used.
"""
view_map = {0: [0]}
......@@ -1529,7 +1538,7 @@ class Assert(T.Op):
return hash(type(self)) ^ hash(self.msg)
def grad(self, input, output_gradients):
return output_gradients
return output_gradients + [DisconnectedType()()] * (len(inputs) - 1)
def c_code(self, node, name, inames, onames, sub):
value = inames[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论