提交 f16c8763 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #1654 from lamblin/fix_mrg_grad_none

Use gradient_undefined instead of None in MRG uniform
...@@ -10,6 +10,7 @@ import warnings ...@@ -10,6 +10,7 @@ import warnings
import numpy import numpy
from theano import Op, Apply, shared, config, Variable from theano import Op, Apply, shared, config, Variable
from theano import gradient
from theano import tensor from theano import tensor
from theano.tensor import (raw_random, TensorType, as_tensor_variable, from theano.tensor import (raw_random, TensorType, as_tensor_variable,
get_vector_length, cast, opt, scal) get_vector_length, cast, opt, scal)
...@@ -175,7 +176,10 @@ class mrg_uniform_base(Op): ...@@ -175,7 +176,10 @@ class mrg_uniform_base(Op):
[rstate.type(), self.output_type()]) [rstate.type(), self.output_type()])
def grad(self, inputs, ograd): def grad(self, inputs, ograd):
return [None for i in inputs] return [gradient.grad_undefined(
self, k, inp,
'No gradient defined through random sampling op')
for k, inp in enumerate(inputs)]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None for i in eval_points] return [None for i in eval_points]
......
...@@ -755,3 +755,19 @@ def test_random_state_transfer(): ...@@ -755,3 +755,19 @@ def test_random_state_transfer():
su2[0].set_value(su1[0].get_value()) su2[0].set_value(su1[0].get_value())
numpy.testing.assert_array_almost_equal(f1(), f2(), decimal=6) numpy.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
def test_gradient_scan():
# Test for a crash when using MRG inside scan and taking the gradient
# See https://groups.google.com/d/msg/theano-dev/UbcYyU5m-M8/UO9UgXqnQP0J
theano_rng = MRG_RandomStreams(10)
w = theano.shared(numpy.ones(1, dtype='float32'))
def one_step(x):
return x + theano_rng.uniform((1,), dtype='float32') * w
x = tensor.vector(dtype='float32')
values, updates = theano.scan(one_step, outputs_info=x, n_steps=10)
gw = theano.grad(tensor.sum(values[-1]), w)
f = theano.function([x], gw)
f(numpy.arange(1, dtype='float32'))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论