提交 02514c6e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Try to fix how Null gradients are propagated in Scan

- Propagate NullType when summing different gradient contributions.
- Identify which outer outputs of the Scan node implementing the grad should be Null.
- Replace them, and the corresponding inner variables, with zeros of the appropriate size, so the Scan node can be compiled and run.
- Distinguish between a gradient that is undefined because it goes through a shared variable with updates and one that is undefined because a NullType was returned by the inner graph.
上级 5fa90044
......@@ -20,6 +20,7 @@ from itertools import izip
import numpy
import theano
from theano.compat import exc_message
from theano.compile import function, Param, Out
from theano import compile
from theano import gradient
......@@ -31,6 +32,7 @@ from theano import tensor
from theano.tensor.opt import Shape_i
from theano.gradient import grad_undefined
from theano.gradient import DisconnectedType
from theano.gradient import NullType
from theano.compile.profiling import ScanProfileStats
from theano.scan_module import scan_utils
......@@ -1487,13 +1489,18 @@ class Scan(PureOp):
consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None')
except gradient.NullTypeGradError:
except gradient.NullTypeGradError, e:
# The gradient wrt that particular input is undefined.
# This is not necessarily an issue, because maybe that
# particular input is not in the path between the
# "cost" and "wrt" of the external, initial call to grad().
# We simply forward the Null gradient.
gmp[x] = x
# We simply return a Null gradient, forwarding the message.
gmp[x] = NullType((
"This variable is Null because the grad method on the "
"inner graph of the Scan node %s returned Null for "
"the corresponding inner input variable. The original "
"message was: %s"
% (str(self), exc_message(e))))()
rval = [gmp.get(p, None) for p in diff_inputs]
return rval
......@@ -1540,8 +1547,17 @@ class Scan(PureOp):
for jdx in xrange(len(_dC_dinps_t)):
if dC_dinps_t[jdx] is None:
dC_dinps_t[jdx] = _dC_dinps_t[jdx]
elif isinstance(dC_dinps_t[jdx].type, NullType):
# The accumulated gradient is undefined
pass
elif _dC_dinps_t[jdx]:
dC_dinps_t[jdx] += _dC_dinps_t[jdx]
if isinstance(_dC_dinps_t[jdx].type, NullType):
# The accumulated gradient is defined, but the new
# term is undefined. The whole thing has to be undefined.
dC_dinps_t[jdx] = _dC_dinps_t[jdx]
else:
dC_dinps_t[jdx] += _dC_dinps_t[jdx]
# mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)):
if not dC_dinps_t[dx]:
......@@ -1563,13 +1579,21 @@ class Scan(PureOp):
opos = self.get_output_pos(pos)
if opos >= 0:
dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if x.dtype != dC_dXts[opos].dtype:
if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype)
else:
dC_dXtm1s.append(safe_new(x))
for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType):
# The accumulated gradient is undefined
pass
elif isinstance(dC_dXtm1.type, NullType):
# The new gradient is undefined, this makes the accumulated
# gradient undefined as well
dC_dinps_t[dx + self.n_seqs] = dC_dXtm1
else:
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op
# Seqs
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
......@@ -1646,7 +1670,8 @@ class Scan(PureOp):
outer_inp_mitmot.append(dC_douts[idx][::-1])
mitmot_inp_taps.append([])
mitmot_out_taps.append([])
undefined = False
undefined_msg = None
through_shared = False
disconnected = True
for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_inp_mitmot.append(dC_dXts[out_pos])
......@@ -1656,21 +1681,33 @@ class Scan(PureOp):
for jdx in xrange(len(self.tap_array[idx])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
tensor.zeros(diff_inputs[ins_pos].shape,
dtype=theano.config.floatX))
undefined_msg = dC_dinps_t[ins_pos].type.why_null
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if not disconnected_dC_dinps_t[ins_pos]:
disconnected = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
through_shared = True
n_mitmot_inps += 1
ins_pos += 1
n_mitmot_outs += 1
mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
mitmot_out_taps[idx].append(-self.tap_array[idx][jdx])
if undefined:
type_outs.append('undefined')
if undefined_msg:
type_outs.append(undefined_msg)
elif through_shared:
type_outs.append('through_shared')
elif disconnected:
type_outs.append('disconnected')
else:
......@@ -1685,12 +1722,23 @@ class Scan(PureOp):
inner_inp_mitmot.append(dC_dXts[out_pos])
out_pos += 1
n_mitmot_inps += 1
undefined = False
undefined_msg = None
through_shared = False
disconnected = True
mitmot_inp_taps[idx + offset].append(0)
for jdx in xrange(len(self.tap_array[idx_tap])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
tensor.zeros(diff_inputs[ins_pos].shape,
dtype=theano.config.floatX))
undefined_msg = dC_dinps_t[ins_pos].type.why_null
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
mitmot_inp_taps[idx + offset].append(
-self.tap_array[idx_tap][jdx])
mitmot_out_taps[idx].append(
......@@ -1699,12 +1747,16 @@ class Scan(PureOp):
disconnected = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
through_shared = True
n_mitmot_inps += 1
ins_pos += 1
n_mitmot_outs += 1
if undefined:
type_outs.append('undefined')
if undefined_msg:
type_outs.append(undefined_msg)
elif through_shared:
type_outs.append('through_shared')
elif disconnected:
type_outs.append('disconnected')
else:
......@@ -1714,19 +1766,39 @@ class Scan(PureOp):
for idx in xrange(self.n_sit_sot):
mitmot_inp_taps.append([0, 1])
mitmot_out_taps.append([1])
undefined = False
through_shared = False
if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
else:
outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape,
dtype=dC_dinps_t[ins_pos].dtype))
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# Cannot use dC_dinps_t[ins_pos].dtype, so we use
# floatX instead, as it is a dummy value that will not
# be used anyway.
outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape,
dtype=theano.config.floatX))
else:
outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape,
dtype=dC_dinps_t[ins_pos].dtype))
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
tensor.zeros(diff_inputs[ins_pos].shape,
dtype=theano.config.floatX))
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
if undefined:
type_outs.append('undefined')
through_shared = True
if isinstance(dC_dinps_t[ins_pos].type, NullType):
type_outs.append(dC_dinps_t[ins_pos].type.why_null)
elif through_shared:
type_outs.append('through_shared')
elif disconnected_dC_dinps_t[ins_pos]:
type_outs.append('disconnected')
else:
......@@ -1746,24 +1818,38 @@ class Scan(PureOp):
inner_out_nitsot = dC_dinps_t[:self.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:]
for _p, vl in enumerate(inner_out_sitsot):
undefined = False
through_shared = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]):
undefined = True
if undefined:
type_outs.append('undefined')
through_shared = True
if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of
# the right shape
inner_out_sitsot[_p] = tensor.zeros(
diff_inputs[ins_pos + _p].shape,
dtype=theano.config.floatX)
elif through_shared:
type_outs.append('through_shared')
elif disconnected_dC_dinps_t[_p + ins_pos]:
type_outs.append('disconnected')
else:
type_outs.append('connected')
for _p, vl in enumerate(inner_out_nitsot):
undefined = False
through_shared = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]):
undefined = True
if undefined:
type_outs.append('undefined')
through_shared = True
if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of
# the right shape
inner_out_sitsot[_p] = tensor.zeros(
diff_inputs[_p].shape,
dtype=theano.config.floatX)
if through_shared:
type_outs.append('through_shared')
elif disconnected_dC_dinps_t[_p]:
type_outs.append('disconnected')
else:
......@@ -1832,29 +1918,36 @@ class Scan(PureOp):
for p, (x, t) in enumerate(
zip(outputs[offset:offset + self.n_seqs],
type_outs[offset:offset + self.n_seqs])):
if t == 'disconnected':
if t == 'connected':
gradients.append(x[::-1])
elif t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
elif t == 'through_shared':
gradients.append(
grad_undefined(self,
p + 1,
inputs[p + 1],
'Depends on a shared variable'))
else:
gradients.append(x[::-1])
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate(
zip(outputs[:end], type_outs[:end])):
if t == 'disconnected':
if t == 'connected':
gradients.append(x[::-1])
elif t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
elif t == 'through_shared':
gradients.append(
grad_undefined(self,
p + 1 + self.n_seqs,
inputs[p + 1 + self.n_seqs],
'Depends on a shared variable'))
else:
gradients.append(x[::-1])
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
start = len(gradients)
node = outs[0].owner
......@@ -1880,16 +1973,20 @@ class Scan(PureOp):
end = begin + n_sitsot_outs
for p, (x, t) in enumerate(
zip(outputs[begin:end], type_outs[begin:end])):
if t == 'disconnected':
if t == 'connected':
gradients.append(x[-1])
elif t == 'disconnected':
gradients.append(DisconnectedType()())
elif t == 'undefined':
elif t == 'through_shared':
gradients.append(
grad_undefined(self,
p + begin + 1,
inputs[p + begin + 1],
'Depends on a shared variable'))
else:
gradients.append(x[-1])
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
# Mask disconnected gradients
# Ideally we would want to assert that the gradients we are
# replacing do indeed evaluate to 0, though that is not practical
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论