提交 fa057306 authored 作者: Vikram's avatar Vikram

R_op for ZeroGrad + tests

上级 56da8ca8
...@@ -1973,6 +1973,11 @@ class ZeroGrad(ViewOp): ...@@ -1973,6 +1973,11 @@ class ZeroGrad(ViewOp):
def grad(self, args, g_outs): def grad(self, args, g_outs):
return [g_out.zeros_like(g_out) for g_out in g_outs] return [g_out.zeros_like(g_out) for g_out in g_outs]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return theano.tensor.zeros(1)
zero_grad_ = ZeroGrad() zero_grad_ = ZeroGrad()
......
...@@ -663,7 +663,7 @@ class TestZeroGrad(unittest.TestCase): ...@@ -663,7 +663,7 @@ class TestZeroGrad(unittest.TestCase):
x = theano.tensor.matrix('x') x = theano.tensor.matrix('x')
y = x * gradient.zero_grad(x) y = x * gradient.zero_grad(x)
f = theano.function([x], y) f = theano.function([x], y)
        # need to refer to theano.gradient.zero_grad here,        # need to refer to theano.gradient.zero_grad here,
# theano.gradient.zero_grad is a wrapper function! # theano.gradient.zero_grad is a wrapper function!
assert gradient.zero_grad_ not in \ assert gradient.zero_grad_ not in \
[node.op for node in f.maker.fgraph.toposort()] [node.op for node in f.maker.fgraph.toposort()]
...@@ -691,6 +691,23 @@ class TestZeroGrad(unittest.TestCase): ...@@ -691,6 +691,23 @@ class TestZeroGrad(unittest.TestCase):
assert np.allclose(f(a), f2(a)) assert np.allclose(f(a), f2(a))
def test_rop(self):
T = theano.tensor
x = T.vector()
v = T.vector()
y = gradient.zero_grad(x)
rop = T.Rop(y, x, v)
f = theano.function([x, v], rop, on_unused_input='ignore')
a = np.asarray(self.rng.randn(5),
dtype=config.floatX)
u = np.asarray(self.rng.randn(5),
dtype=config.floatX)
assert np.count_nonzero(f(a,u)) == 0
class TestDisconnectedGrad(unittest.TestCase): class TestDisconnectedGrad(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论