Commit d9671761 authored by abergeron

Merge pull request #2720 from yingzha/ccw

added documentation and fixed the gradient of the Assert Op
...@@ -24,6 +24,7 @@ from theano.gof import opt, InconsistencyError, TopoOptimizer, graph ...@@ -24,6 +24,7 @@ from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant from theano.gof import Variable, Constant
from theano.compat.python2x import maxsize from theano.compat.python2x import maxsize
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType
from theano.configparser import config from theano.configparser import config
from theano.tensor.elemwise import Elemwise, DimShuffle from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
...@@ -1492,11 +1493,22 @@ class Assert(T.Op): ...@@ -1492,11 +1493,22 @@ class Assert(T.Op):
""" """
Implements assertion in a computational graph. Implements assertion in a computational graph.
Returns the first parameter if the condition is true, otherwise, triggers
AssertionError.
Example:
T = theano.tensor
x = T.vector('x')
assert_op = T.opt.Assert()
func = theano.function([x], assert_op(x, x.size<2))
Notes: Notes:
This Op can be removed from the graph because of optimizations, and can This Op is a debugging feature. It can be removed from the graph
hide some possible optimizations to the optimizer. because of optimizations, and can hide some possible optimizations to
Also, the output of the Op must be returned by the function computing the the optimizer. Specifically, removing happens if it can be determined
graph, otherwise it will not be used. that condition will always be true. Also, the output of the Op must be
used in the function computing the graph, but it doesn't have to be
returned.
""" """
view_map = {0: [0]} view_map = {0: [0]}
...@@ -1533,7 +1545,7 @@ class Assert(T.Op): ...@@ -1533,7 +1545,7 @@ class Assert(T.Op):
return hash(type(self)) ^ hash(self.msg) return hash(type(self)) ^ hash(self.msg)
def grad(self, input, output_gradients): def grad(self, input, output_gradients):
return output_gradients return output_gradients + [DisconnectedType()()] * (len(input) - 1)
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
value = inames[0] value = inames[0]
......
import numpy
import theano
import theano.tensor as T
from theano.tensor.opt import Assert
def test_assert_op_gradient():
    """The gradient of Assert must flow through its first input unchanged.

    Builds cost = sum(assert_op(x, cond)); since d(sum(x))/dx is a vector
    of ones, the compiled gradient function must return ones. This covers
    the fix that makes Assert.grad return DisconnectedType for the
    condition inputs while passing the output gradient to the value input.
    """
    x = T.vector('x')
    assert_op = Assert()
    cost = T.sum(assert_op(x, x.size < 2))
    grad = T.grad(cost, x)
    func = theano.function([x], grad)

    # Use floatX so the input is accepted under float32 configurations.
    x_val = numpy.ones(shape=(1,), dtype=theano.config.floatX)
    # Compare elementwise rather than relying on the truth value of a
    # one-element array (fragile, and deprecated in numpy).
    assert numpy.allclose(func(x_val), 1)
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Sign up or sign in to post a comment