提交 a1703e4a authored 作者: Ying Zhang's avatar Ying Zhang

added documentation and fixed the gradient of the Assert Op

上级 e9328fdd
...@@ -24,6 +24,7 @@ from theano.gof import opt, InconsistencyError, TopoOptimizer, graph ...@@ -24,6 +24,7 @@ from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant from theano.gof import Variable, Constant
from theano.compat.python2x import maxsize from theano.compat.python2x import maxsize
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType
from theano.configparser import config from theano.configparser import config
from theano.tensor.elemwise import Elemwise, DimShuffle from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
...@@ -1487,12 +1488,20 @@ def local_cast_cast(node): ...@@ -1487,12 +1488,20 @@ def local_cast_cast(node):
class Assert(T.Op): class Assert(T.Op):
""" """
Implements assertion in a computational graph. Implements assertion in a computational graph.
Returns the first parameter if the condition is true, otherwise, trigger
AssertionError.
Example:
T = theano.tensor
x = T.vector('x')
assert_op = T.opt.Assert()
func = theano.function([x], assert_op(x, x.size<2))
Notes: Notes:
This Op can be removed from the graph because of optimizations, and can This Op is an debugging feature. It can be removed from the graph
hide some possible optimizations to the optimizer. because of optimizations, and can hide some possible optimizations to the
Also, the output of the Op must be returned by the function computing the optimizer. Also, the output of the Op must be returned by the function
graph, otherwise it will not be used. computing the graph, otherwise it will not be used.
""" """
view_map = {0: [0]} view_map = {0: [0]}
...@@ -1529,7 +1538,7 @@ class Assert(T.Op): ...@@ -1529,7 +1538,7 @@ class Assert(T.Op):
return hash(type(self)) ^ hash(self.msg) return hash(type(self)) ^ hash(self.msg)
def grad(self, input, output_gradients):
    """Gradient of the Assert Op.

    Theano's grad protocol passes *all* node inputs as one list
    (here named ``input``).  The first input is the value that Assert
    passes through, so it receives the incoming output gradient
    unchanged; every remaining input is a boolean condition, which is
    not differentiable, so each gets a ``DisconnectedType`` gradient.

    Bug fix: the previous version referenced an undefined name
    ``inputs`` (the parameter is spelled ``input``), which raised a
    ``NameError`` whenever the Op had more than one input.
    """
    # One DisconnectedType gradient per condition input (all inputs
    # except the first, pass-through one).
    return output_gradients + [DisconnectedType()()] * (len(input) - 1)
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
value = inames[0] value = inames[0]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论