提交 02514c6e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Try to fix how Null gradients are propagated in Scan

- propagate NullType when summing different gradient contributions
- identify which outer outputs of the scan node implementing the grad should be Null
- replace them, and the corresponding inner variables, with zeros of the appropriate size, so the scan node can be compiled and run
- distinguish between gradients that are undefined because they go through a shared variable with updates, and gradients for which a NullType was returned by the internal graph
上级 5fa90044
...@@ -20,6 +20,7 @@ from itertools import izip ...@@ -20,6 +20,7 @@ from itertools import izip
import numpy import numpy
import theano import theano
from theano.compat import exc_message
from theano.compile import function, Param, Out from theano.compile import function, Param, Out
from theano import compile from theano import compile
from theano import gradient from theano import gradient
...@@ -31,6 +32,7 @@ from theano import tensor ...@@ -31,6 +32,7 @@ from theano import tensor
from theano.tensor.opt import Shape_i from theano.tensor.opt import Shape_i
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.gradient import NullType
from theano.compile.profiling import ScanProfileStats from theano.compile.profiling import ScanProfileStats
from theano.scan_module import scan_utils from theano.scan_module import scan_utils
...@@ -1487,13 +1489,18 @@ class Scan(PureOp): ...@@ -1487,13 +1489,18 @@ class Scan(PureOp):
consider_constant=wrt, consider_constant=wrt,
disconnected_inputs='ignore', disconnected_inputs='ignore',
return_disconnected='None') return_disconnected='None')
except gradient.NullTypeGradError: except gradient.NullTypeGradError, e:
# The gradient wrt that particular input is undefined. # The gradient wrt that particular input is undefined.
# This is not necessarily an issue, because maybe that # This is not necessarily an issue, because maybe that
# particular input is not in the path between the # particular input is not in the path between the
# "cost" and "wrt" of the external, initial call to grad(). # "cost" and "wrt" of the external, initial call to grad().
# We simply forward the Null gradient. # We simply return a Null gradient, forwarding the message.
gmp[x] = x gmp[x] = NullType((
"This variable is Null because the grad method on the "
"inner graph of the Scan node %s returned Null for "
"the corresponding inner input variable. The original "
"message was: %s"
% (str(self), exc_message(e))))()
rval = [gmp.get(p, None) for p in diff_inputs] rval = [gmp.get(p, None) for p in diff_inputs]
return rval return rval
...@@ -1540,8 +1547,17 @@ class Scan(PureOp): ...@@ -1540,8 +1547,17 @@ class Scan(PureOp):
for jdx in xrange(len(_dC_dinps_t)): for jdx in xrange(len(_dC_dinps_t)):
if dC_dinps_t[jdx] is None: if dC_dinps_t[jdx] is None:
dC_dinps_t[jdx] = _dC_dinps_t[jdx] dC_dinps_t[jdx] = _dC_dinps_t[jdx]
elif isinstance(dC_dinps_t[jdx].type, NullType):
# The accumulated gradient is undefined
pass
elif _dC_dinps_t[jdx]: elif _dC_dinps_t[jdx]:
dC_dinps_t[jdx] += _dC_dinps_t[jdx] if isinstance(_dC_dinps_t[jdx].type, NullType):
# The accumulated gradient is defined, but the new
# term is undefined. The whole thing has to be undefined.
dC_dinps_t[jdx] = _dC_dinps_t[jdx]
else:
dC_dinps_t[jdx] += _dC_dinps_t[jdx]
# mask inputs that get no gradients # mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)): for dx in xrange(len(dC_dinps_t)):
if not dC_dinps_t[dx]: if not dC_dinps_t[dx]:
...@@ -1563,13 +1579,21 @@ class Scan(PureOp): ...@@ -1563,13 +1579,21 @@ class Scan(PureOp):
opos = self.get_output_pos(pos) opos = self.get_output_pos(pos)
if opos >= 0: if opos >= 0:
dC_dXtm1s.append(safe_new(dC_dXts[opos])) dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if x.dtype != dC_dXts[opos].dtype: if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \ dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype) x.astype(dC_dXts[opos].dtype)
else: else:
dC_dXtm1s.append(safe_new(x)) dC_dXtm1s.append(safe_new(x))
for dx, dC_dXtm1 in enumerate(dC_dXtm1s): for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1 if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType):
# The accumulated gradient is undefined
pass
elif isinstance(dC_dXtm1.type, NullType):
# The new gradient is undefined, this makes the accumulated
# gradient undefined as well
dC_dinps_t[dx + self.n_seqs] = dC_dXtm1
else:
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op # Construct scan op
# Seqs # Seqs
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]] outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
...@@ -1646,7 +1670,8 @@ class Scan(PureOp): ...@@ -1646,7 +1670,8 @@ class Scan(PureOp):
outer_inp_mitmot.append(dC_douts[idx][::-1]) outer_inp_mitmot.append(dC_douts[idx][::-1])
mitmot_inp_taps.append([]) mitmot_inp_taps.append([])
mitmot_out_taps.append([]) mitmot_out_taps.append([])
undefined = False undefined_msg = None
through_shared = False
disconnected = True disconnected = True
for jdx in xrange(len(self.mit_mot_out_slices[idx])): for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_inp_mitmot.append(dC_dXts[out_pos]) inner_inp_mitmot.append(dC_dXts[out_pos])
...@@ -1656,21 +1681,33 @@ class Scan(PureOp): ...@@ -1656,21 +1681,33 @@ class Scan(PureOp):
for jdx in xrange(len(self.tap_array[idx])): for jdx in xrange(len(self.tap_array[idx])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dC_dinps_t[ins_pos]) if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
tensor.zeros(diff_inputs[ins_pos].shape,
dtype=theano.config.floatX))
undefined_msg = dC_dinps_t[ins_pos].type.why_null
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if not disconnected_dC_dinps_t[ins_pos]: if not disconnected_dC_dinps_t[ins_pos]:
disconnected = False disconnected = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]): if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True through_shared = True
n_mitmot_inps += 1 n_mitmot_inps += 1
ins_pos += 1 ins_pos += 1
n_mitmot_outs += 1 n_mitmot_outs += 1
mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx]) mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
mitmot_out_taps[idx].append(-self.tap_array[idx][jdx]) mitmot_out_taps[idx].append(-self.tap_array[idx][jdx])
if undefined:
type_outs.append('undefined') if undefined_msg:
type_outs.append(undefined_msg)
elif through_shared:
type_outs.append('through_shared')
elif disconnected: elif disconnected:
type_outs.append('disconnected') type_outs.append('disconnected')
else: else:
...@@ -1685,12 +1722,23 @@ class Scan(PureOp): ...@@ -1685,12 +1722,23 @@ class Scan(PureOp):
inner_inp_mitmot.append(dC_dXts[out_pos]) inner_inp_mitmot.append(dC_dXts[out_pos])
out_pos += 1 out_pos += 1
n_mitmot_inps += 1 n_mitmot_inps += 1
undefined = False undefined_msg = None
through_shared = False
disconnected = True disconnected = True
mitmot_inp_taps[idx + offset].append(0) mitmot_inp_taps[idx + offset].append(0)
for jdx in xrange(len(self.tap_array[idx_tap])): for jdx in xrange(len(self.tap_array[idx_tap])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
tensor.zeros(diff_inputs[ins_pos].shape,
dtype=theano.config.floatX))
undefined_msg = dC_dinps_t[ins_pos].type.why_null
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
mitmot_inp_taps[idx + offset].append( mitmot_inp_taps[idx + offset].append(
-self.tap_array[idx_tap][jdx]) -self.tap_array[idx_tap][jdx])
mitmot_out_taps[idx].append( mitmot_out_taps[idx].append(
...@@ -1699,12 +1747,16 @@ class Scan(PureOp): ...@@ -1699,12 +1747,16 @@ class Scan(PureOp):
disconnected = False disconnected = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]): if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True through_shared = True
n_mitmot_inps += 1 n_mitmot_inps += 1
ins_pos += 1 ins_pos += 1
n_mitmot_outs += 1 n_mitmot_outs += 1
if undefined:
type_outs.append('undefined') if undefined_msg:
type_outs.append(undefined_msg)
elif through_shared:
type_outs.append('through_shared')
elif disconnected: elif disconnected:
type_outs.append('disconnected') type_outs.append('disconnected')
else: else:
...@@ -1714,19 +1766,39 @@ class Scan(PureOp): ...@@ -1714,19 +1766,39 @@ class Scan(PureOp):
for idx in xrange(self.n_sit_sot): for idx in xrange(self.n_sit_sot):
mitmot_inp_taps.append([0, 1]) mitmot_inp_taps.append([0, 1])
mitmot_out_taps.append([1]) mitmot_out_taps.append([1])
undefined = False through_shared = False
if not isinstance(dC_douts[idx + offset].type, DisconnectedType): if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
else: else:
outer_inp_mitmot.append( if isinstance(dC_dinps_t[ins_pos].type, NullType):
tensor.zeros(outs[idx + offset].shape, # Cannot use dC_dinps_t[ins_pos].dtype, so we use
dtype=dC_dinps_t[ins_pos].dtype)) # floatX instead, as it is a dummy value that will not
inner_out_mitmot.append(dC_dinps_t[ins_pos]) # be used anyway.
outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape,
dtype=theano.config.floatX))
else:
outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape,
dtype=dC_dinps_t[ins_pos].dtype))
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
tensor.zeros(diff_inputs[ins_pos].shape,
dtype=theano.config.floatX))
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]): if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True through_shared = True
if undefined:
type_outs.append('undefined') if isinstance(dC_dinps_t[ins_pos].type, NullType):
type_outs.append(dC_dinps_t[ins_pos].type.why_null)
elif through_shared:
type_outs.append('through_shared')
elif disconnected_dC_dinps_t[ins_pos]: elif disconnected_dC_dinps_t[ins_pos]:
type_outs.append('disconnected') type_outs.append('disconnected')
else: else:
...@@ -1746,24 +1818,38 @@ class Scan(PureOp): ...@@ -1746,24 +1818,38 @@ class Scan(PureOp):
inner_out_nitsot = dC_dinps_t[:self.n_seqs] inner_out_nitsot = dC_dinps_t[:self.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:] inner_out_sitsot = dC_dinps_t[ins_pos:]
for _p, vl in enumerate(inner_out_sitsot): for _p, vl in enumerate(inner_out_sitsot):
undefined = False through_shared = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]): if _sh in gof.graph.inputs([vl]):
undefined = True through_shared = True
if undefined: if isinstance(vl.type, NullType):
type_outs.append('undefined') type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of
# the right shape
inner_out_sitsot[_p] = tensor.zeros(
diff_inputs[ins_pos + _p].shape,
dtype=theano.config.floatX)
elif through_shared:
type_outs.append('through_shared')
elif disconnected_dC_dinps_t[_p + ins_pos]: elif disconnected_dC_dinps_t[_p + ins_pos]:
type_outs.append('disconnected') type_outs.append('disconnected')
else: else:
type_outs.append('connected') type_outs.append('connected')
for _p, vl in enumerate(inner_out_nitsot): for _p, vl in enumerate(inner_out_nitsot):
undefined = False through_shared = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]): if _sh in gof.graph.inputs([vl]):
undefined = True through_shared = True
if undefined: if isinstance(vl.type, NullType):
type_outs.append('undefined') type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of
# the right shape
inner_out_sitsot[_p] = tensor.zeros(
diff_inputs[_p].shape,
dtype=theano.config.floatX)
if through_shared:
type_outs.append('through_shared')
elif disconnected_dC_dinps_t[_p]: elif disconnected_dC_dinps_t[_p]:
type_outs.append('disconnected') type_outs.append('disconnected')
else: else:
...@@ -1832,29 +1918,36 @@ class Scan(PureOp): ...@@ -1832,29 +1918,36 @@ class Scan(PureOp):
for p, (x, t) in enumerate( for p, (x, t) in enumerate(
zip(outputs[offset:offset + self.n_seqs], zip(outputs[offset:offset + self.n_seqs],
type_outs[offset:offset + self.n_seqs])): type_outs[offset:offset + self.n_seqs])):
if t == 'disconnected': if t == 'connected':
gradients.append(x[::-1])
elif t == 'disconnected':
gradients.append(DisconnectedType()()) gradients.append(DisconnectedType()())
elif t == 'undefined': elif t == 'through_shared':
gradients.append( gradients.append(
grad_undefined(self, grad_undefined(self,
p + 1, p + 1,
inputs[p + 1], inputs[p + 1],
'Depends on a shared variable')) 'Depends on a shared variable'))
else: else:
gradients.append(x[::-1]) # t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate( for p, (x, t) in enumerate(
zip(outputs[:end], type_outs[:end])): zip(outputs[:end], type_outs[:end])):
if t == 'disconnected': if t == 'connected':
gradients.append(x[::-1])
elif t == 'disconnected':
gradients.append(DisconnectedType()()) gradients.append(DisconnectedType()())
elif t == 'undefined': elif t == 'through_shared':
gradients.append( gradients.append(
grad_undefined(self, grad_undefined(self,
p + 1 + self.n_seqs, p + 1 + self.n_seqs,
inputs[p + 1 + self.n_seqs], inputs[p + 1 + self.n_seqs],
'Depends on a shared variable')) 'Depends on a shared variable'))
else: else:
gradients.append(x[::-1]) # t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
start = len(gradients) start = len(gradients)
node = outs[0].owner node = outs[0].owner
...@@ -1880,16 +1973,20 @@ class Scan(PureOp): ...@@ -1880,16 +1973,20 @@ class Scan(PureOp):
end = begin + n_sitsot_outs end = begin + n_sitsot_outs
for p, (x, t) in enumerate( for p, (x, t) in enumerate(
zip(outputs[begin:end], type_outs[begin:end])): zip(outputs[begin:end], type_outs[begin:end])):
if t == 'disconnected': if t == 'connected':
gradients.append(x[-1])
elif t == 'disconnected':
gradients.append(DisconnectedType()()) gradients.append(DisconnectedType()())
elif t == 'undefined': elif t == 'through_shared':
gradients.append( gradients.append(
grad_undefined(self, grad_undefined(self,
p + begin + 1, p + begin + 1,
inputs[p + begin + 1], inputs[p + begin + 1],
'Depends on a shared variable')) 'Depends on a shared variable'))
else: else:
gradients.append(x[-1]) # t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
# Mask disconnected gradients # Mask disconnected gradients
# Ideally we would want to assert that the gradients we are # Ideally we would want to assert that the gradients we are
# replacing do indeed evaluate to 0, though that is not practical # replacing do indeed evaluate to 0, though that is not practical
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论