Commit f16c8763 authored by Razvan Pascanu

Merge pull request #1654 from lamblin/fix_mrg_grad_none

Use gradient_undefined instead of None in MRG uniform
......@@ -10,6 +10,7 @@ import warnings
import numpy
from theano import Op, Apply, shared, config, Variable
from theano import gradient
from theano import tensor
from theano.tensor import (raw_random, TensorType, as_tensor_variable,
get_vector_length, cast, opt, scal)
......@@ -175,7 +176,10 @@ class mrg_uniform_base(Op):
[rstate.type(), self.output_type()])
def grad(self, inputs, ograd):
    """Return an explicitly-undefined gradient for every input.

    Sampling from a random stream is not differentiable.  Instead of
    returning ``None`` for each input (which theano.grad interprets as
    "disconnected" and which crashed when this op was used inside scan),
    return a ``grad_undefined`` marker per input, carrying a message
    explaining why no gradient exists.
    """
    # NOTE: the stale pre-merge ``return [None for i in inputs]`` line is
    # removed here; it shadowed the grad_undefined return below.
    return [gradient.grad_undefined(
        self, k, inp,
        'No gradient defined through random sampling op')
        for k, inp in enumerate(inputs)]
def R_op(self, inputs, eval_points):
    """The R-operator is not defined through a random sampling op.

    Returns one ``None`` per evaluation point to signal that no Rop
    can be propagated through this node.
    """
    return [None for _ in eval_points]
......
......@@ -755,3 +755,19 @@ def test_random_state_transfer():
su2[0].set_value(su1[0].get_value())
numpy.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
def test_gradient_scan():
    # Regression test for a crash when an MRG random stream is sampled
    # inside scan and the gradient is taken: the MRG uniform op used to
    # return None gradients instead of grad_undefined markers.
    # See https://groups.google.com/d/msg/theano-dev/UbcYyU5m-M8/UO9UgXqnQP0J
    theano_rng = MRG_RandomStreams(10)
    w = theano.shared(numpy.ones(1, dtype='float32'))

    def one_step(x):
        # Inner scan step: add noise scaled by the shared weight ``w``.
        return x + theano_rng.uniform((1,), dtype='float32') * w

    x = tensor.vector(dtype='float32')
    values, updates = theano.scan(one_step, outputs_info=x, n_steps=10)
    # Building and compiling the gradient wrt ``w`` must not raise.
    gw = theano.grad(tensor.sum(values[-1]), w)
    f = theano.function([x], gw)
    f(numpy.arange(1, dtype='float32'))
......@@ -20,6 +20,7 @@ from itertools import izip
import numpy
import theano
from theano.compat import exc_message
from theano.compile import function, Param, Out
from theano import compile
from theano import gradient
......@@ -31,6 +32,7 @@ from theano import tensor
from theano.tensor.opt import Shape_i
from theano.gradient import grad_undefined
from theano.gradient import DisconnectedType
from theano.gradient import NullType
from theano.compile.profiling import ScanProfileStats
from theano.scan_module import scan_utils
......@@ -1305,7 +1307,9 @@ class Scan(PureOp):
known_grads={y: g_y}, wrt=x)
except gradient.NullTypeGradError:
# It means the gradient is undefined (which implies
# is connected)
# is connected).
# Warning: x is not the right gradient here, but the only
# thing we will check later is whether it is None.
gmp[x] = x
except gradient.DisconnectedInputError:
gmp[x] = None
......@@ -1476,16 +1480,33 @@ class Scan(PureOp):
if (x in diff_inputs) and
(connection_pattern[
get_inp_idx(self_inputs.index(x))][odx])]
grads = gradient.grad(
cost=None,
known_grads={y: g_y},
wrt=wrt,
consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None')
gmp = dict(zip(wrt, grads))
gmp = OrderedDict()
for x in wrt:
try:
gmp[x] = gradient.grad(
cost=None,
known_grads={y: g_y},
wrt=x,
consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None')
except gradient.NullTypeGradError, e:
# The gradient wrt that particular input is undefined.
# This is not necessarily an issue, because maybe that
# particular input is not in the path between the
# "cost" and "wrt" of the external, initial call to grad().
# We simply return a Null gradient, forwarding the message.
gmp[x] = NullType((
"This variable is Null because the grad method on the "
"inner graph of the Scan node %s returned Null for "
"the corresponding inner input variable. The original "
"message was: %s"
% (str(self), exc_message(e))))()
rval = [gmp.get(p, None) for p in diff_inputs]
return rval
dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs]
dC_dXts = []
......@@ -1528,8 +1549,17 @@ class Scan(PureOp):
for jdx in xrange(len(_dC_dinps_t)):
if dC_dinps_t[jdx] is None:
dC_dinps_t[jdx] = _dC_dinps_t[jdx]
elif isinstance(dC_dinps_t[jdx].type, NullType):
# The accumulated gradient is undefined
pass
elif _dC_dinps_t[jdx]:
dC_dinps_t[jdx] += _dC_dinps_t[jdx]
if isinstance(_dC_dinps_t[jdx].type, NullType):
# The accumulated gradient is defined, but the new
# term is undefined. The whole thing has to be undefined.
dC_dinps_t[jdx] = _dC_dinps_t[jdx]
else:
dC_dinps_t[jdx] += _dC_dinps_t[jdx]
# mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)):
if not dC_dinps_t[dx]:
......@@ -1551,13 +1581,21 @@ class Scan(PureOp):
opos = self.get_output_pos(pos)
if opos >= 0:
dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if x.dtype != dC_dXts[opos].dtype:
if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype)
else:
dC_dXtm1s.append(safe_new(x))
for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType):
# The accumulated gradient is undefined
pass
elif isinstance(dC_dXtm1.type, NullType):
# The new gradient is undefined, this makes the accumulated
# gradient undefined as well
dC_dinps_t[dx + self.n_seqs] = dC_dXtm1
else:
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op
# Seqs
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
......@@ -1634,7 +1672,8 @@ class Scan(PureOp):
outer_inp_mitmot.append(dC_douts[idx][::-1])
mitmot_inp_taps.append([])
mitmot_out_taps.append([])
undefined = False
undefined_msg = None
through_shared = False
disconnected = True
for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_inp_mitmot.append(dC_dXts[out_pos])
......@@ -1644,21 +1683,33 @@ class Scan(PureOp):
for jdx in xrange(len(self.tap_array[idx])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
tensor.zeros(diff_inputs[ins_pos].shape,
dtype=theano.config.floatX))
undefined_msg = dC_dinps_t[ins_pos].type.why_null
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if not disconnected_dC_dinps_t[ins_pos]:
disconnected = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
through_shared = True
n_mitmot_inps += 1
ins_pos += 1
n_mitmot_outs += 1
mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
mitmot_out_taps[idx].append(-self.tap_array[idx][jdx])
if undefined:
type_outs.append('undefined')
if undefined_msg:
type_outs.append(undefined_msg)
elif through_shared:
type_outs.append('through_shared')
elif disconnected:
type_outs.append('disconnected')
else:
......@@ -1673,12 +1724,23 @@ class Scan(PureOp):
inner_inp_mitmot.append(dC_dXts[out_pos])
out_pos += 1
n_mitmot_inps += 1
undefined = False
undefined_msg = None
through_shared = False
disconnected = True
mitmot_inp_taps[idx + offset].append(0)
for jdx in xrange(len(self.tap_array[idx_tap])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
tensor.zeros(diff_inputs[ins_pos].shape,
dtype=theano.config.floatX))
undefined_msg = dC_dinps_t[ins_pos].type.why_null
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
mitmot_inp_taps[idx + offset].append(
-self.tap_array[idx_tap][jdx])
mitmot_out_taps[idx].append(
......@@ -1687,12 +1749,16 @@ class Scan(PureOp):
disconnected = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
through_shared = True
n_mitmot_inps += 1
ins_pos += 1
n_mitmot_outs += 1
if undefined:
type_outs.append('undefined')
if undefined_msg:
type_outs.append(undefined_msg)
elif through_shared:
type_outs.append('through_shared')
elif disconnected:
type_outs.append('disconnected')
else:
......@@ -1702,26 +1768,46 @@ class Scan(PureOp):
for idx in xrange(self.n_sit_sot):
mitmot_inp_taps.append([0, 1])
mitmot_out_taps.append([1])
undefined = False
through_shared = False
if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
else:
outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape,
dtype=dC_dinps_t[ins_pos].dtype))
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# Cannot use dC_dinps_t[ins_pos].dtype, so we use
# floatX instead, as it is a dummy value that will not
# be used anyway.
outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape,
dtype=theano.config.floatX))
else:
outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape,
dtype=dC_dinps_t[ins_pos].dtype))
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
tensor.zeros(diff_inputs[ins_pos].shape,
dtype=theano.config.floatX))
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
if undefined:
type_outs.append('undefined')
through_shared = True
if isinstance(dC_dinps_t[ins_pos].type, NullType):
type_outs.append(dC_dinps_t[ins_pos].type.why_null)
elif through_shared:
type_outs.append('through_shared')
elif disconnected_dC_dinps_t[ins_pos]:
type_outs.append('disconnected')
else:
type_outs.append('connected')
inner_inp_mitmot += [dC_dXts[out_pos],
dC_dXtm1s[ins_pos - self.n_seqs]]
dC_dXtm1s[ins_pos - self.n_seqs]]
n_mitmot_outs += 1
out_pos += 1
ins_pos += 1
......@@ -1734,24 +1820,38 @@ class Scan(PureOp):
inner_out_nitsot = dC_dinps_t[:self.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:]
for _p, vl in enumerate(inner_out_sitsot):
undefined = False
through_shared = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]):
undefined = True
if undefined:
type_outs.append('undefined')
through_shared = True
if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of
# the right shape
inner_out_sitsot[_p] = tensor.zeros(
diff_inputs[ins_pos + _p].shape,
dtype=theano.config.floatX)
elif through_shared:
type_outs.append('through_shared')
elif disconnected_dC_dinps_t[_p + ins_pos]:
type_outs.append('disconnected')
else:
type_outs.append('connected')
for _p, vl in enumerate(inner_out_nitsot):
undefined = False
through_shared = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]):
undefined = True
if undefined:
type_outs.append('undefined')
through_shared = True
if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of
# the right shape
inner_out_sitsot[_p] = tensor.zeros(
diff_inputs[_p].shape,
dtype=theano.config.floatX)
if through_shared:
type_outs.append('through_shared')
elif disconnected_dC_dinps_t[_p]:
type_outs.append('disconnected')
else:
......@@ -1790,12 +1890,12 @@ class Scan(PureOp):
info['mode'] = self.mode
outer_inputs = ([grad_steps] +
outer_inp_seqs +
outer_inp_mitmot +
outer_inp_sitsot +
[inputs[0] for x in xrange(n_nit_sot)] +
self.outer_shared(inputs) +
self.outer_non_seqs(inputs))
outer_inp_seqs +
outer_inp_mitmot +
outer_inp_sitsot +
[inputs[0] for x in xrange(n_nit_sot)] +
self.outer_shared(inputs) +
self.outer_non_seqs(inputs))
inner_other_args = self_inputs[offset:]
inner_gfn_ins = (inner_inp_seqs +
......@@ -1806,6 +1906,7 @@ class Scan(PureOp):
inner_gfn_outs = (inner_out_mitmot +
inner_out_sitsot +
inner_out_nitsot)
local_op = Scan(inner_gfn_ins, inner_gfn_outs, info)
outputs = local_op(*outer_inputs)
if type(outputs) not in (list, tuple):
......@@ -1820,29 +1921,36 @@ class Scan(PureOp):
for p, (x, t) in enumerate(
zip(outputs[offset:offset + self.n_seqs],
type_outs[offset:offset + self.n_seqs])):
if t == 'disconnected':
if t == 'connected':
gradients.append(x[::-1])
elif t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
elif t == 'through_shared':
gradients.append(
grad_undefined(self,
p + 1,
inputs[p + 1],
'Depends on a shared variable'))
else:
gradients.append(x[::-1])
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate(
zip(outputs[:end], type_outs[:end])):
if t == 'disconnected':
if t == 'connected':
gradients.append(x[::-1])
elif t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
elif t == 'through_shared':
gradients.append(
grad_undefined(self,
p + 1 + self.n_seqs,
inputs[p + 1 + self.n_seqs],
'Depends on a shared variable'))
else:
gradients.append(x[::-1])
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
start = len(gradients)
node = outs[0].owner
......@@ -1868,16 +1976,20 @@ class Scan(PureOp):
end = begin + n_sitsot_outs
for p, (x, t) in enumerate(
zip(outputs[begin:end], type_outs[begin:end])):
if t == 'disconnected':
if t == 'connected':
gradients.append(x[-1])
elif t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
elif t == 'through_shared':
gradients.append(
grad_undefined(self,
p + begin + 1,
inputs[p + begin + 1],
'Depends on a shared variable'))
else:
gradients.append(x[-1])
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
# Mask disconnected gradients
# Ideally we would want to assert that the gradients we are
# replacing do indeed evaluate to 0, though that is not practical
......
Markdown formatting is supported.
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment.