Commit f16c8763 authored by Razvan Pascanu

Merge pull request #1654 from lamblin/fix_mrg_grad_none

Use gradient_undefined instead of None in MRG uniform
@@ -10,6 +10,7 @@ import warnings
 import numpy

 from theano import Op, Apply, shared, config, Variable
+from theano import gradient
 from theano import tensor
 from theano.tensor import (raw_random, TensorType, as_tensor_variable,
                            get_vector_length, cast, opt, scal)
@@ -175,7 +176,10 @@ class mrg_uniform_base(Op):
                      [rstate.type(), self.output_type()])

     def grad(self, inputs, ograd):
-        return [None for i in inputs]
+        return [gradient.grad_undefined(
+            self, k, inp,
+            'No gradient defined through random sampling op')
+            for k, inp in enumerate(inputs)]

     def R_op(self, inputs, eval_points):
         return [None for i in eval_points]
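Note for readers unfamiliar with grad_undefined: it does not raise on the spot. It returns a variable of NullType carrying an explanatory message, so an error only surfaces if some consumer actually asks for that gradient. A minimal standalone sketch (passing None for the op argument is a shortcut for illustration only):

from theano import gradient, tensor

x = tensor.vector('x')
g = gradient.grad_undefined(
    None, 0, x, 'No gradient defined through random sampling op')
assert isinstance(g.type, gradient.NullType)
print g.type.why_null  # explains which input of which op has no gradient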
@@ -755,3 +755,19 @@ def test_random_state_transfer():
     su2[0].set_value(su1[0].get_value())

     numpy.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
+
+
+def test_gradient_scan():
+    # Test for a crash when using MRG inside scan and taking the gradient
+    # See https://groups.google.com/d/msg/theano-dev/UbcYyU5m-M8/UO9UgXqnQP0J
+    theano_rng = MRG_RandomStreams(10)
+    w = theano.shared(numpy.ones(1, dtype='float32'))
+
+    def one_step(x):
+        return x + theano_rng.uniform((1,), dtype='float32') * w
+
+    x = tensor.vector(dtype='float32')
+    values, updates = theano.scan(one_step, outputs_info=x, n_steps=10)
+    gw = theano.grad(tensor.sum(values[-1]), w)
+    f = theano.function([x], gw)
+    f(numpy.arange(1, dtype='float32'))
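Two notes on the test. First, the gradient it requests is genuinely well defined: the samples are constants with respect to w, so for the recurrence in one_step,

# x_t = x_{t-1} + u_t * w      (u_t ~ U(0, 1), treated as a constant)
# d x_T / d w = sum_t u_t      (nothing flows through the sampling op)

Second, before this patch the MRG op's grad returned plain None, which the gradient machinery could not interpret, hence the crash discussed in the linked thread; with grad_undefined, the undefined gradient is an explicit NullType value that the Scan changes below can carry along and ultimately discard unused.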
@@ -20,6 +20,7 @@ from itertools import izip
 import numpy

 import theano
+from theano.compat import exc_message
 from theano.compile import function, Param, Out
 from theano import compile
 from theano import gradient
@@ -31,6 +32,7 @@ from theano import tensor
 from theano.tensor.opt import Shape_i
 from theano.gradient import grad_undefined
 from theano.gradient import DisconnectedType
+from theano.gradient import NullType
 from theano.compile.profiling import ScanProfileStats
 from theano.scan_module import scan_utils
@@ -1305,7 +1307,9 @@ class Scan(PureOp):
                         known_grads={y: g_y}, wrt=x)
                 except gradient.NullTypeGradError:
                     # It means the gradient is undefined (which implies
-                    # is connected)
+                    # is connected).
+                    # Warning: x is not the right gradient here, but the only
+                    # thing we will check later is whether it is None.
                     gmp[x] = x
                 except gradient.DisconnectedInputError:
                     gmp[x] = None
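The known_grads idiom used here (cost=None, seed the backward pass with a supplied output gradient) is central to how Scan differentiates its inner graph. A minimal example independent of Scan:

import theano
from theano import gradient, tensor

x = tensor.vector('x')
y = 2 * x
g_y = tensor.ones_like(y)  # pretend this gradient arrived from further up
g_x = gradient.grad(cost=None, known_grads={y: g_y}, wrt=x)
f = theano.function([x], g_x)
print f([1., 2., 3.])  # [ 2.  2.  2.]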
@@ -1476,16 +1480,33 @@ class Scan(PureOp):
                    if (x in diff_inputs) and
                       (connection_pattern[
                            get_inp_idx(self_inputs.index(x))][odx])]
-            grads = gradient.grad(
-                cost=None,
-                known_grads={y: g_y},
-                wrt=wrt,
-                consider_constant=wrt,
-                disconnected_inputs='ignore',
-                return_disconnected='None')
-            gmp = dict(zip(wrt, grads))
+            gmp = OrderedDict()
+            for x in wrt:
+                try:
+                    gmp[x] = gradient.grad(
+                        cost=None,
+                        known_grads={y: g_y},
+                        wrt=x,
+                        consider_constant=wrt,
+                        disconnected_inputs='ignore',
+                        return_disconnected='None')
+                except gradient.NullTypeGradError, e:
+                    # The gradient wrt that particular input is undefined.
+                    # This is not necessarily an issue, because maybe that
+                    # particular input is not in the path between the
+                    # "cost" and "wrt" of the external, initial call to grad().
+                    # We simply return a Null gradient, forwarding the message.
+                    gmp[x] = NullType((
+                        "This variable is Null because the grad method on the "
+                        "inner graph of the Scan node %s returned Null for "
+                        "the corresponding inner input variable. The original "
+                        "message was: %s"
+                        % (str(self), exc_message(e))))()

             rval = [gmp.get(p, None) for p in diff_inputs]
             return rval

         dC_dinps_t = [None for inp in diff_inputs]
         disconnected_dC_dinps_t = [True for inp in diff_inputs]
         dC_dXts = []
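Note the forwarding pattern: NullType is a Theano Type, and calling an instance of it, NullType(msg)(), yields a Variable of that type, so the explanation travels with the graph. A stripped-down sketch of the same idea (the helper name forward_null is mine, not the patch's):

from theano.gradient import NullType

def forward_null(scan_node_str, original_msg):
    # Chain the inner op's explanation onto the Scan node forwarding it.
    return NullType(
        "This variable is Null because the grad method on the inner "
        "graph of the Scan node %s returned Null. The original message "
        "was: %s" % (scan_node_str, original_msg))()

g = forward_null('scan_node', 'No gradient defined through random sampling op')
assert isinstance(g.type, NullType)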
@@ -1528,8 +1549,17 @@ class Scan(PureOp):
                 for jdx in xrange(len(_dC_dinps_t)):
                     if dC_dinps_t[jdx] is None:
                         dC_dinps_t[jdx] = _dC_dinps_t[jdx]
+                    elif isinstance(dC_dinps_t[jdx].type, NullType):
+                        # The accumulated gradient is undefined
+                        pass
                     elif _dC_dinps_t[jdx]:
-                        dC_dinps_t[jdx] += _dC_dinps_t[jdx]
+                        if isinstance(_dC_dinps_t[jdx].type, NullType):
+                            # The accumulated gradient is defined, but the
+                            # new term is undefined. The whole thing has to
+                            # be undefined.
+                            dC_dinps_t[jdx] = _dC_dinps_t[jdx]
+                        else:
+                            dC_dinps_t[jdx] += _dC_dinps_t[jdx]

             # mask inputs that get no gradients
             for dx in xrange(len(dC_dinps_t)):
                 if not dC_dinps_t[dx]:
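The branching above implements a small three-valued accumulation rule: None means "no term collected yet", a NullType term is absorbing, anything else adds. A hypothetical helper (not in the patch) condensing the rule:

def accumulate_grad_term(acc, term):
    if acc is None:
        # No gradient term collected yet.
        return term
    if isinstance(acc.type, NullType):
        # An already-undefined accumulator stays undefined.
        return acc
    if term is not None and isinstance(term.type, NullType):
        # An undefined new term poisons the whole sum.
        return term
    return acc + term if term is not None else acc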
@@ -1551,12 +1581,20 @@ class Scan(PureOp):
                 opos = self.get_output_pos(pos)
                 if opos >= 0:
                     dC_dXtm1s.append(safe_new(dC_dXts[opos]))
-                    if x.dtype != dC_dXts[opos].dtype:
+                    if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype:
                         dC_dinps_t[pos + self.n_seqs] = \
                             x.astype(dC_dXts[opos].dtype)
                 else:
                     dC_dXtm1s.append(safe_new(x))
             for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
-                dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
+                if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType):
+                    # The accumulated gradient is undefined
+                    pass
+                elif isinstance(dC_dXtm1.type, NullType):
+                    # The new gradient is undefined, this makes the
+                    # accumulated gradient undefined as well
+                    dC_dinps_t[dx + self.n_seqs] = dC_dXtm1
+                else:
+                    dC_dinps_t[dx + self.n_seqs] += dC_dXtm1

             # Construct scan op
             # Seqs
@@ -1634,7 +1672,8 @@ class Scan(PureOp):
                 outer_inp_mitmot.append(dC_douts[idx][::-1])
                 mitmot_inp_taps.append([])
                 mitmot_out_taps.append([])
-                undefined = False
+                undefined_msg = None
+                through_shared = False
                 disconnected = True
                 for jdx in xrange(len(self.mit_mot_out_slices[idx])):
                     inner_inp_mitmot.append(dC_dXts[out_pos])
@@ -1644,21 +1683,33 @@ class Scan(PureOp):
                 for jdx in xrange(len(self.tap_array[idx])):
                     inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
-                    inner_out_mitmot.append(dC_dinps_t[ins_pos])
+                    if isinstance(dC_dinps_t[ins_pos].type, NullType):
+                        # We cannot use Null in the inner graph, so we
+                        # use a zero tensor of the appropriate shape instead.
+                        inner_out_mitmot.append(
+                            tensor.zeros(diff_inputs[ins_pos].shape,
+                                         dtype=theano.config.floatX))
+                        undefined_msg = dC_dinps_t[ins_pos].type.why_null
+                    else:
+                        inner_out_mitmot.append(dC_dinps_t[ins_pos])

                     if not disconnected_dC_dinps_t[ins_pos]:
                         disconnected = False

                     for _sh in self.inner_shared(self_inputs):
                         if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
-                            undefined = True
+                            through_shared = True

                     n_mitmot_inps += 1
                     ins_pos += 1
                     n_mitmot_outs += 1
                     mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
                     mitmot_out_taps[idx].append(-self.tap_array[idx][jdx])

-                if undefined:
-                    type_outs.append('undefined')
+                if undefined_msg:
+                    type_outs.append(undefined_msg)
+                elif through_shared:
+                    type_outs.append('through_shared')
                 elif disconnected:
                     type_outs.append('disconnected')
                 else:
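A bookkeeping note that helps with the remaining hunks: type_outs records, for each gradient output, one of the tags 'connected', 'disconnected' or 'through_shared', or else a free-form why_null message, and the dispatch loops at the end of the patch map each tag back to a gradient object. In summary:

# type_outs entry      ->  gradient returned for that position
# 'connected'          ->  the computed gradient, time-reversed (x[::-1])
# 'disconnected'       ->  DisconnectedType()()
# 'through_shared'     ->  grad_undefined(..., 'Depends on a shared variable')
# any other string t   ->  NullType(t)()  (t is a why_null message)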
@@ -1673,12 +1724,23 @@ class Scan(PureOp):
                 inner_inp_mitmot.append(dC_dXts[out_pos])
                 out_pos += 1
                 n_mitmot_inps += 1
-                undefined = False
+                undefined_msg = None
+                through_shared = False
                 disconnected = True
                 mitmot_inp_taps[idx + offset].append(0)
                 for jdx in xrange(len(self.tap_array[idx_tap])):
                     inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
-                    inner_out_mitmot.append(dC_dinps_t[ins_pos])
+                    if isinstance(dC_dinps_t[ins_pos].type, NullType):
+                        # We cannot use Null in the inner graph, so we
+                        # use a zero tensor of the appropriate shape instead.
+                        inner_out_mitmot.append(
+                            tensor.zeros(diff_inputs[ins_pos].shape,
+                                         dtype=theano.config.floatX))
+                        undefined_msg = dC_dinps_t[ins_pos].type.why_null
+                    else:
+                        inner_out_mitmot.append(dC_dinps_t[ins_pos])
                     mitmot_inp_taps[idx + offset].append(
                         -self.tap_array[idx_tap][jdx])
                     mitmot_out_taps[idx].append(
@@ -1687,12 +1749,16 @@ class Scan(PureOp):
                     disconnected = False

                     for _sh in self.inner_shared(self_inputs):
                         if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
-                            undefined = True
+                            through_shared = True

                     n_mitmot_inps += 1
                     ins_pos += 1
                     n_mitmot_outs += 1

-                if undefined:
-                    type_outs.append('undefined')
+                if undefined_msg:
+                    type_outs.append(undefined_msg)
+                elif through_shared:
+                    type_outs.append('through_shared')
                 elif disconnected:
                     type_outs.append('disconnected')
                 else:
@@ -1702,19 +1768,39 @@ class Scan(PureOp):
         for idx in xrange(self.n_sit_sot):
             mitmot_inp_taps.append([0, 1])
             mitmot_out_taps.append([1])
-            undefined = False
+            through_shared = False
             if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
                 outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
             else:
-                outer_inp_mitmot.append(
-                    tensor.zeros(outs[idx + offset].shape,
-                                 dtype=dC_dinps_t[ins_pos].dtype))
+                if isinstance(dC_dinps_t[ins_pos].type, NullType):
+                    # Cannot use dC_dinps_t[ins_pos].dtype, so we use
+                    # floatX instead, as it is a dummy value that will not
+                    # be used anyway.
+                    outer_inp_mitmot.append(
+                        tensor.zeros(outs[idx + offset].shape,
+                                     dtype=theano.config.floatX))
+                else:
+                    outer_inp_mitmot.append(
+                        tensor.zeros(outs[idx + offset].shape,
+                                     dtype=dC_dinps_t[ins_pos].dtype))

-            inner_out_mitmot.append(dC_dinps_t[ins_pos])
+            if isinstance(dC_dinps_t[ins_pos].type, NullType):
+                # We cannot use Null in the inner graph, so we
+                # use a zero tensor of the appropriate shape instead.
+                inner_out_mitmot.append(
+                    tensor.zeros(diff_inputs[ins_pos].shape,
+                                 dtype=theano.config.floatX))
+            else:
+                inner_out_mitmot.append(dC_dinps_t[ins_pos])

             for _sh in self.inner_shared(self_inputs):
                 if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
-                    undefined = True
+                    through_shared = True

-            if undefined:
-                type_outs.append('undefined')
+            if isinstance(dC_dinps_t[ins_pos].type, NullType):
+                type_outs.append(dC_dinps_t[ins_pos].type.why_null)
+            elif through_shared:
+                type_outs.append('through_shared')
             elif disconnected_dC_dinps_t[ins_pos]:
                 type_outs.append('disconnected')
             else:
@@ -1734,24 +1820,38 @@ class Scan(PureOp):
         inner_out_nitsot = dC_dinps_t[:self.n_seqs]
         inner_out_sitsot = dC_dinps_t[ins_pos:]
         for _p, vl in enumerate(inner_out_sitsot):
-            undefined = False
+            through_shared = False
             for _sh in self.inner_shared(self_inputs):
                 if _sh in gof.graph.inputs([vl]):
-                    undefined = True
+                    through_shared = True

-            if undefined:
-                type_outs.append('undefined')
+            if isinstance(vl.type, NullType):
+                type_outs.append(vl.type.why_null)
+                # Replace the inner output with a zero tensor of
+                # the right shape
+                inner_out_sitsot[_p] = tensor.zeros(
+                    diff_inputs[ins_pos + _p].shape,
+                    dtype=theano.config.floatX)
+            elif through_shared:
+                type_outs.append('through_shared')
             elif disconnected_dC_dinps_t[_p + ins_pos]:
                 type_outs.append('disconnected')
             else:
                 type_outs.append('connected')

         for _p, vl in enumerate(inner_out_nitsot):
-            undefined = False
+            through_shared = False
             for _sh in self.inner_shared(self_inputs):
                 if _sh in gof.graph.inputs([vl]):
-                    undefined = True
+                    through_shared = True

-            if undefined:
-                type_outs.append('undefined')
+            if isinstance(vl.type, NullType):
+                type_outs.append(vl.type.why_null)
+                # Replace the inner output with a zero tensor of
+                # the right shape
+                inner_out_nitsot[_p] = tensor.zeros(
+                    diff_inputs[_p].shape,
+                    dtype=theano.config.floatX)
+            elif through_shared:
+                type_outs.append('through_shared')
             elif disconnected_dC_dinps_t[_p]:
                 type_outs.append('disconnected')
             else:
@@ -1806,6 +1906,7 @@ class Scan(PureOp):
         inner_gfn_outs = (inner_out_mitmot +
                           inner_out_sitsot +
                           inner_out_nitsot)
+
         local_op = Scan(inner_gfn_ins, inner_gfn_outs, info)
         outputs = local_op(*outer_inputs)
         if type(outputs) not in (list, tuple):
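Conceptually, local_op is the gradient Scan: backpropagation through a loop is itself a loop visiting the steps in reverse, which is why connected gradients in the dispatch below are flipped with [::-1] back into forward time order. In rough pseudocode (mine, not from the patch):

# forward:   for t in 0 .. T-1:   h[t + 1] = f(h[t], w)
# backward:  for t in T-1 .. 0:   g_h[t] = df_dh(h[t], w) * g_h[t + 1]
#                                 g_w   += df_dw(h[t], w) * g_h[t + 1]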
@@ -1820,29 +1921,36 @@ class Scan(PureOp):
         for p, (x, t) in enumerate(
                 zip(outputs[offset:offset + self.n_seqs],
                     type_outs[offset:offset + self.n_seqs])):
-            if t == 'disconnected':
+            if t == 'connected':
+                gradients.append(x[::-1])
+            elif t == 'disconnected':
                 gradients.append(DisconnectedType()())
-            elif t == 'undefined':
+            elif t == 'through_shared':
                 gradients.append(
                     grad_undefined(self,
                                    p + 1,
                                    inputs[p + 1],
                                    'Depends on a shared variable'))
             else:
-                gradients.append(x[::-1])
+                # t contains the "why_null" string of a NullType
+                gradients.append(NullType(t)())

         end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
         for p, (x, t) in enumerate(
                 zip(outputs[:end], type_outs[:end])):
-            if t == 'disconnected':
+            if t == 'connected':
+                gradients.append(x[::-1])
+            elif t == 'disconnected':
                 gradients.append(DisconnectedType()())
-            elif t == 'undefined':
+            elif t == 'through_shared':
                 gradients.append(
                     grad_undefined(self,
                                    p + 1 + self.n_seqs,
                                    inputs[p + 1 + self.n_seqs],
                                    'Depends on a shared variable'))
             else:
-                gradients.append(x[::-1])
+                # t contains the "why_null" string of a NullType
+                gradients.append(NullType(t)())

         start = len(gradients)
         node = outs[0].owner
@@ -1868,16 +1976,20 @@ class Scan(PureOp):
         end = begin + n_sitsot_outs
         for p, (x, t) in enumerate(
                 zip(outputs[begin:end], type_outs[begin:end])):
-            if t == 'disconnected':
+            if t == 'connected':
+                gradients.append(x[-1])
+            elif t == 'disconnected':
                 gradients.append(DisconnectedType()())
-            elif t == 'undefined':
+            elif t == 'through_shared':
                 gradients.append(
                     grad_undefined(self,
                                    p + begin + 1,
                                    inputs[p + begin + 1],
                                    'Depends on a shared variable'))
             else:
-                gradients.append(x[-1])
+                # t contains the "why_null" string of a NullType
+                gradients.append(NullType(t)())

         # Mask disconnected gradients
         # Ideally we would want to assert that the gradients we are
         # replacing do indeed evaluate to 0, though that is not practical
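The three nearly identical dispatch loops above could be read through a single hypothetical helper (name and factoring mine; the patch keeps them inline, and sit-sot outputs take x[-1] rather than x[::-1]):

def tag_to_gradient(op, pos, inp, x, t):
    if t == 'connected':
        return x[::-1]  # undo the time reversal of the backward scan
    if t == 'disconnected':
        return DisconnectedType()()
    if t == 'through_shared':
        return grad_undefined(op, pos, inp, 'Depends on a shared variable')
    # Any other tag is the why_null message of an undefined gradient.
    return NullType(t)()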