Commit 3a85ea97 authored by Razvan Pascanu

new implementation of grad method

Parent 25d8089a
@@ -34,7 +34,7 @@ from theano.gradient import DisconnectedType
from theano.compile.profiling import ScanProfileStats
import scan_utils
from scan_utils import safe_new
from scan_utils import safe_new, forced_replace
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_op')
@@ -1194,198 +1194,138 @@ class Scan(PureOp):
for o, x in izip(node.outputs, scan_outs)]
return scan_outs
### GRAD FUNCTION
def grad(self, args, g_outs):
# This discards information about whether incoming gradients are 0
# or disconnected from the cost
# TODO: upgrade scan op to report disconnection correctly
def strip_disconnected(g):
if isinstance(g.type, DisconnectedType):
return None
return g
g_outs = [strip_disconnected(g) for g in g_outs]
# 1. forward pass - get the outputs after applying scan
scan_outputs = self(*args)
# 2. make sure they are given as a list
if not(type(scan_outputs) in (list, tuple)):
scan_outputs = [scan_outputs]
# 3. un-group / unzip the inputs
        # Note! We don't want to use the very same variables as the
        # ones used by the original scan; rather, we create clones
        # of them.
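    # Helper: given the index of an output of the scan, return the
    # position of the corresponding tapped inner input (e.g. the
    # input carrying x[t-1] for an output x[t]), or -1 if the
    # output has no such input.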
def get_input_pos(self, output_index):
ipos = self.n_seqs
opos = output_index
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
if len(otaps) > opos:
return ipos
else:
opos = opos - len(otaps)
ipos += len(itaps)
for dx, taps in enumerate(self.mitsot_taps()):
if opos == 0:
return ipos
else:
opos = opos - 1
ipos += len(taps)
if opos < self.info['n_sit_sot']:
return ipos + opos
else:
return -1
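    # Helper: the inverse of get_input_pos -- given the position of
    # a tapped inner input, return the index of the output it
    # belongs to, or -1 if there is none.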
def get_output_pos(self, input_index):
ipos = input_index
opos = 0
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
if len(itaps) > ipos:
return opos
else:
opos += len(otaps)
ipos -= len(itaps)
for dx, taps in enumerate(self.mitsot_taps()):
if len(taps) > ipos:
return opos
else:
opos += 1
ipos -= len(taps)
if ipos < self.info['n_sit_sot']:
return ipos + opos
else:
return -1
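    # Helper: map an output index to the index of its first output
    # slice (mit-mot outputs contribute several slices per step).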
def get_output_slice_idx(self, output_index):
ipos = 0
opos = output_index
        for otaps in self.mitmot_out_taps():
if len(otaps) > 0:
return ipos
else:
opos = opos - 1
ipos += len(otaps)
return ipos + opos
### GRAD FUNCTION
def grad(self, inputs, dC_douts):
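        # Overall strategy: evaluate the forward scan, then build a
        # second Scan op that runs backward in time.  The forward
        # states and the incoming cost gradients (dC_douts) are fed
        # in reversed as sequences / mit-mot inputs, gradients that
        # flow between consecutive steps travel through mit-mot
        # taps, and the gradients w.r.t. the non-sequences are
        # accumulated in sit-sot states.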
outs = self(*inputs)
if not isinstance(outs, (list, tuple)):
outs = [outs]
rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs, '_grad')
self.outputs)
self_inputs = rval[0]
self_outputs = rval[1]
#differentiable inputs
diff_inputs = (self.inner_seqs(self_inputs) +
self.inner_mitmot(self_inputs) +
self.inner_mitsot(self_inputs) +
self.inner_sitsot(self_inputs) +
self.inner_non_seqs(self_inputs))
diff_outputs = (self.inner_mitmot_outs(self_outputs) +
self.inner_mitsot_outs(self_outputs) +
self.inner_sitsot_outs(self_outputs) +
self.inner_nitsot_outs(self_outputs))
seqs = self_inputs[:self.n_seqs]
offset = self.n_seqs
n_ins_mit_mot = numpy.sum([0] + [len(self.tap_array[x]) for x
in xrange(self.n_mit_mot)])
outs_mit_mot = self_inputs[offset:offset + n_ins_mit_mot]
offset += n_ins_mit_mot
n_ins_mit_sot = numpy.sum([0] + [len(self.tap_array[x]) for x
in xrange(self.n_mit_mot,
self.n_mit_mot + self.n_mit_sot)])
outs_mit_sot = self_inputs[offset:offset + n_ins_mit_sot]
offset += n_ins_mit_sot
outs_sit_sot = self_inputs[offset:offset + self.n_sit_sot]
offset += self.n_sit_sot
old_scan_shared_ins = self_inputs[offset:offset + self.n_shared_outs]
out_offset = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_nit_sot +
self.n_sit_sot)
# shared variables as well as the condition
old_scan_shared_outs = self_outputs[out_offset:]
arg_offset = (1 +
self.n_seqs +
self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot)
old_scan_init = args[arg_offset: arg_offset + self.n_shared_outs]
offset += self.n_shared_outs
other_args = self_inputs[offset:]
# 4. Collect (possibly) differentiable inputs
diff_inputs = (seqs +
outs_mit_mot +
outs_mit_sot +
outs_sit_sot +
other_args)
#args[-len(other_args):] )
# 5. construct the function that computes the gradient (we sum over
# the gradients with respect to all outputs)
def compute_gradient(y, g_y):
gmp = gradient.grad_sources_inputs(
[(y, g_y)], diff_inputs)
return [gmp.get(p, None) for p in diff_inputs]
# 6. clean the outputs (i.e. remove update rules)
end = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot)
clean_outputs = self_outputs[:end]
g_outs_no_shared = g_outs[:end]
# 7.1. empty lists to hold gradients
# List of slices from outputs (used to compute the gradients)
inner_g_outs = []
g_out_slices = []
# List of outputs of the gradient function
inner_gfn_outs = []
# slices of the input
prev_inner_gfn_outs = []
zeros_like_diff_ins = []
pos = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot)
offset = len(args) - len(other_args) - pos
# 7.2. generate variables to represent previous steps of g_outs
for idx, diff_in in enumerate(diff_inputs):
prev_gfn_out = safe_new(diff_in)
if hasattr(diff_in, 'name') and diff_in.name:
prev_gfn_out.name = 'g_prev_' + diff_in.name
else:
prev_gfn_out.name = 'g_prev_' + str(idx)
prev_inner_gfn_outs.append(prev_gfn_out)
if idx < pos:
zeros_like_diff_ins.append(tensor.zeros_like(diff_in))
else:
zeros_like_diff_ins.append(
tensor.zeros_like(args[idx + offset]))
# 7.3. compute gradients of the inputs given one output
for dx, out in enumerate(clean_outputs):
if g_outs[dx] != None:
inner_g_out = safe_new(g_outs[dx][0])
else:
# We do not have a gradient on this output so we need a
# placeholder, which for now has the same dtype as the
# output
inner_g_out = safe_new(out)
###
#### I need to clip the gradient HERE !!
if g_outs_no_shared[dx]:
g_out_slices.append(g_outs_no_shared[dx][0])
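        # Naming convention for the backward step: dXt_inps[j] holds
        # dC/d(inner input j) at step t; dXtp1_dXts are placeholders
        # for the gradient slices that arrive from step t+1; Xts are
        # placeholders that stand in for the forward values X[t]
        # inside the gradient expressions.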
dXt_inps = [None for inp in diff_inputs]
dXtp1_dXts = []
Xts = []
for idx, Xt in enumerate(diff_outputs):
# We are looking for x[t-1] for a given x[t]
if idx >= self.n_mit_mot_outs:
Xt_placeholder = Xt.type()
Xts.append(Xt_placeholder)
Xtm1_pos = self.get_input_pos(idx)
if Xtm1_pos >= 0:
Xtm1 = self_inputs[Xtm1_pos]
                # It is possible that X[t] is not actually a function
                # of x[t-1], in which case we cannot rely on this
                # information
try:
tmp = tensor.grad(Xt.sum(), Xtm1)
except ValueError:
tmp = Xt
dXtp1_dXt = safe_new(tmp)
else:
g_out_slices.append(None)
if getattr(out, 'name', None) is not None:
inner_g_out.name = 'g_' + out.name
if isinstance(dC_douts[idx].type, DisconnectedType):
continue
dXtp1_dXt = safe_new(dC_douts[idx][0])
dXtp1_dXts.append(dXtp1_dXt)
_dXt_inps = compute_gradient(Xt, dXtp1_dXt)
for jdx in xrange(len(_dXt_inps)):
if dXt_inps[jdx] is None:
dXt_inps[jdx] = _dXt_inps[jdx]
elif _dXt_inps[jdx]:
dXt_inps[jdx] += _dXt_inps[jdx]
# mask inputs that get no gradients
for dx in xrange(len(dXt_inps)):
if not dXt_inps[dx]:
dXt_inps[dx] = tensor.zeros_like(diff_inputs[dx])
else:
inner_g_out.name = 'g_' + str(dx)
inner_g_outs.append(inner_g_out)
_g_out = inner_g_out
grad_outs = compute_gradient(out, _g_out)
if not inner_gfn_outs:
for idx, gfn_out in enumerate(grad_outs):
if idx >= self.n_seqs:
inner_gfn_outs.append(prev_inner_gfn_outs[idx])
else:
inner_gfn_outs.append(None)
# 7.4 Sum the gradients
            # safety check: some of these inputs might still not be
            # differentiable; for those we don't add them to the mix
            # (assume their gradient is 0)
for i, (x, y) in enumerate(zip(grad_outs, inner_gfn_outs)):
if x and y:
inner_gfn_outs[i] = x + y
elif y:
inner_gfn_outs[i] = y
else:
inner_gfn_outs[i] = x
## 8. Mask the outputs that are not differentiable
# backwards pass
for i in xrange(len(inner_gfn_outs)):
if inner_gfn_outs[i] is None:
inner_gfn_outs[i] = tensor.zeros_like(diff_inputs[i])
        ## 9. Mask the g_outs that are None:
for i, out in enumerate(scan_outputs):
if g_outs[i] is None:
try:
                    # this try catches non-ndarray inputs (random
                    # states); it is more of a safety check, since
                    # all random states should come after
                    # n_outs_not_shared ...
g_outs[i] = tensor.zeros_like(scan_outputs[i])
except Exception:
g_outs[i] = theano.tensor.constant(
numpy.array(0, theano.config.floatX))
## 10. Get your sequence in order for the scan:
n_seqs = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot +
self.n_nit_sot)
offset = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot)
inner_seqs = (seqs +
outs_mit_mot +
outs_mit_sot +
outs_sit_sot +
inner_g_outs[offset:offset + self.n_nit_sot])
scan_seqs = [x[::-1] for x in args[1:self.n_seqs + 1]]
offset = 0
        for Xt, Xt_placeholder in zip(
                        diff_outputs[self.n_mit_mot_outs:],
                        Xts):
            # Substitute the placeholder for Xt inside every
            # gradient expression, so that the backward graph refers
            # to X[t] through its placeholder instead of reaching
            # into the forward graph
            for dx in xrange(len(dXt_inps)):
                dXt_inps[dx] = forced_replace(dXt_inps[dx],
                                              Xt,
                                              Xt_placeholder)
# construct dX_dtm1
dXt_dXtm1s = [x.type() for x in dXt_inps[self.n_seqs:]]
for dx, dXt_dXtm1 in enumerate(dXt_dXtm1s):
dXt_inps[dx+self.n_seqs] += dXt_dXtm1
# Construct scan op
# Seqs
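        # Note that every outer sequence is reversed with [::-1]:
        # the gradient scan iterates over the forward outputs
        # backward in time.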
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = numpy.min(self.tap_array[idx])
maxtap = numpy.max(self.tap_array[idx])
seq = scan_outputs[offset + idx]
seq = outs[idx]
for k in self.tap_array[idx]:
                # We cut the sequence such that seq[i] corresponds
                # to seq[i-k]
if maxtap < 0:
dim_offset = abs(maxtap)
else:
@@ -1397,124 +1337,125 @@ class Scan(PureOp):
-(maxtap - k + 1)][::-1]
else:
nw_seq = seq[dim_offset + k - mintap - 1: -1][::-1]
if getattr(seq, 'name', None) is not None:
nw_seq.name = seq.name + '[%d:]' % k
scan_seqs.append(nw_seq)
offset += self.n_mit_sot
for idx in xrange(self.n_sit_sot):
seq = scan_outputs[offset + idx][:-1]
scan_seqs.append(seq[::-1])
offset = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot)
scan_seqs += [x[::-1] for x in
g_outs[offset:offset + self.n_nit_sot]]
scan_mit_mot = []
inner_mit_mot = []
scan_mit_mot_outs = []
mit_mot_taps = []
mit_mot_out_slices = []
outer_inp_seqs.append(nw_seq)
outer_inp_seqs += [
x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
for x in self.outer_nitsot_outs(dC_douts):
if not isinstance(x.type, DisconnectedType):
outer_inp_seqs.append(x[::-1])
outer_inp_seqs += [x[::-1] for x in self.outer_mitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
inner_inp_seqs = self.inner_seqs(self_inputs)
inner_inp_seqs += self.inner_mitmot(self_inputs)
inner_inp_seqs += self.inner_mitsot(self_inputs)
inner_inp_seqs += self.inner_sitsot(self_inputs)
inner_inp_seqs += self.inner_nitsot_outs(dXtp1_dXts)
inner_inp_seqs += Xts
# mitmot
outer_inp_mitmot = []
outer_out_mitmot = []
inner_inp_mitmot = []
inner_out_mitmot = []
mitmot_inp_taps = []
mitmot_out_taps = []
out_pos = 0
ins_pos = n_seqs
n_mit_mot_outs = 0
n_mit_mot_ins = 0
ins_pos = self.n_seqs
n_mitmot_outs = 0
n_mitmot_inps = 0
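        # Every recurrent forward output turns into a mit-mot of the
        # gradient scan: the slices of the incoming gradient use the
        # (negated) forward output taps as input taps, and the
        # gradients w.r.t. the tapped states are emitted with the
        # (negated) forward input taps as output taps.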
for idx in xrange(self.n_mit_mot):
scan_mit_mot.append(g_outs[idx][::-1])
mit_mot_taps.append([])
mit_mot_out_slices.append([])
outer_inp_mitmot.append(dC_douts[idx][::-1])
mitmot_inp_taps.append([])
mitmot_out_taps.append([])
for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_mit_mot.append(inner_g_outs[out_pos])
mit_mot_taps[idx].append(\
-self.mit_mot_out_slices[idx][jdx])
n_mit_mot_ins += 1
inner_inp_mitmot.append(dXtp1_dXts[out_pos])
mitmot_inp_taps[idx].append(-self.mit_mot_out_slices[idx][jdx])
n_mitmot_inps += 1
out_pos += 1
for jdx in xrange(len(self.tap_array[idx])):
inner_mit_mot.append(prev_inner_gfn_outs[ins_pos])
scan_mit_mot_outs.append(\
inner_gfn_outs[ins_pos])
n_mit_mot_ins += 1
inner_inp_mitmot.append(dXt_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dXt_inps[ins_pos])
                n_mitmot_inps += 1
ins_pos += 1
n_mit_mot_outs += 1
mit_mot_taps[idx].append(-self.tap_array[idx][jdx])
mit_mot_out_slices[idx].append(\
-self.tap_array[idx][jdx])
n_mitmot_outs += 1
mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
mitmot_out_taps[idx].append(-self.tap_array[idx][jdx])
offset = self.n_mit_mot
for idx in xrange(self.n_mit_sot):
mit_mot_taps.append([])
mit_mot_out_slices.append([])
scan_mit_mot.append(g_outs[idx + offset][::-1])
mitmot_inp_taps.append([])
mitmot_out_taps.append([])
outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
idx_tap = idx + self.n_mit_mot
inner_inp_mitmot.append(dXtp1_dXts[out_pos])
out_pos += 1
n_mitmot_inps += 1
mitmot_inp_taps[idx + offset].append(0)
for jdx in xrange(len(self.tap_array[idx_tap])):
inner_mit_mot.append(prev_inner_gfn_outs[ins_pos])
mit_mot_taps[idx + offset].append(\
inner_inp_mitmot.append(dXt_dXtm1s[ins_pos - self.n_seqs])
mitmot_inp_taps[idx + offset].append(
-self.tap_array[idx_tap][jdx])
mit_mot_out_slices[idx].append(\
mitmot_out_taps[idx].append(
-self.tap_array[idx_tap][jdx])
scan_mit_mot_outs.append(inner_gfn_outs[ins_pos])
n_mit_mot_ins += 1
inner_out_mitmot.append(dXt_inps[ins_pos])
n_mitmot_inps += 1
ins_pos += 1
n_mit_mot_outs += 1
inner_mit_mot.append(inner_g_outs[out_pos])
out_pos += 1
n_mit_mot_ins += 1
mit_mot_taps[idx + offset].append(0)
n_mitmot_outs += 1
offset += self.n_mit_sot
for idx in xrange(self.n_sit_sot):
mit_mot_taps.append([0, 1])
mit_mot_out_slices.append([1])
scan_mit_mot.append(g_outs[idx + offset][::-1])
scan_mit_mot_outs.append(inner_gfn_outs[ins_pos])
inner_mit_mot += [inner_g_outs[out_pos],
prev_inner_gfn_outs[ins_pos]]
n_mit_mot_outs += 1
mitmot_inp_taps.append([0, 1])
mitmot_out_taps.append([1])
if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
else:
outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape,
dtype = dXt_inps[ins_pos].dtype))
inner_out_mitmot.append(dXt_inps[ins_pos])
inner_inp_mitmot += [dXtp1_dXts[out_pos],
dXt_dXtm1s[ins_pos - self.n_seqs]]
n_mitmot_outs += 1
out_pos += 1
ins_pos += 1
n_mit_mot_ins += 2
n_nit_sot = self.n_seqs
scan_nit_sot_outs = inner_gfn_outs[:self.n_seqs]
n_mitmot_inps += 2
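        # truncate_gradient != -1 requests truncated BPTT: the
        # backward scan takes at most `truncate_gradient` steps.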
if self.truncate_gradient != -1:
do_steps = tensor.minimum(args[0], self.truncate_gradient)
do_steps = tensor.minimum(inputs[0], self.truncate_gradient)
else:
do_steps = args[0]
offset = (self.n_seqs +
n_ins_mit_sot +
n_ins_mit_mot +
self.n_sit_sot)
# Instead of shared outs use sit_sot
n_sitsot_outs = len(prev_inner_gfn_outs[offset:])
scan_sitsot_ins = prev_inner_gfn_outs[offset:]
scan_sitsot_init = []
for x in zeros_like_diff_ins[offset:]:
shapes = [x.shape[i] for i in xrange(x.ndim)]
empty = tensor.zeros([do_steps + 1] + shapes,
dtype=x.dtype)
scan_sitsot_init.append(empty)
scan_sitsot_outs = inner_gfn_outs[offset:]
tap_array = mit_mot_taps + [[-1] for k in
do_steps = inputs[0]
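        # Gradients w.r.t. the sequences come out as nit-sot outputs
        # (one per forward sequence), while gradients w.r.t. the
        # non-sequences are accumulated step by step in sit-sot
        # states initialised with zeros.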
n_nit_sot = self.n_seqs
inner_out_nitsot = dXt_inps[:self.n_seqs]
inner_out_sitsot = dXt_inps[ins_pos:]
inner_inp_sitsot = dXt_dXtm1s[ins_pos - self.n_seqs:]
outer_inp_sitsot = [
tensor.zeros([do_steps + 1] +
[x.shape[i] for i in xrange(x.ndim)],
dtype = y.dtype)
for y, x in zip(inner_inp_sitsot,
self.outer_non_seqs(inputs))]
n_sitsot_outs = len(outer_inp_sitsot)
new_tap_array = mitmot_inp_taps + [[-1] for k in
xrange(n_sitsot_outs)]
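        # Configuration of the gradient scan: all recurrences are
        # expressed as mit-mot, plus one sit-sot per non-sequence
        # gradient and one nit-sot per sequence gradient.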
info = {}
info['n_seqs'] = n_seqs
info['n_seqs'] = len(outer_inp_seqs)
info['n_mit_sot'] = 0
info['tap_array'] = tap_array
info['tap_array'] = new_tap_array
info['gpu'] = False
n_mit_mot = (self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot)
info['n_mit_mot'] = n_mit_mot
info['n_mit_mot_outs'] = n_mit_mot_outs
info['mit_mot_out_slices'] = mit_mot_out_slices
info['n_mit_mot'] = len(outer_inp_mitmot)
info['n_mit_mot_outs'] = n_mitmot_outs
info['mit_mot_out_slices'] = mitmot_out_taps
info['truncate_gradient'] = self.truncate_gradient
info['n_sit_sot'] = n_sitsot_outs
info['n_shared_outs'] = self.n_shared_outs
info['n_shared_outs'] = 0
info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while
info['profile'] = self.profile
@@ -1524,47 +1465,30 @@ class Scan(PureOp):
else:
info['name'] = None
info['mode'] = self.mode
n_mit_sot = 0
n_sit_sot = 0
offset = (1 +
self.n_seqs +
self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot +
self.n_shared_outs)
scan_inputs = ([do_steps] +
scan_seqs +
scan_mit_mot +
scan_sitsot_init +
old_scan_init +
[args[0] for x in xrange(n_nit_sot)] +
args[offset:])
offset = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot +
self.n_shared_outs)
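        # Assemble the outer inputs of the gradient scan in Scan's
        # canonical order: n_steps, sequences, mit-mot initial
        # states, sit-sot initial states, nit-sot lengths, shared
        # variables and non-sequences.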
outer_inputs = ([do_steps] +
outer_inp_seqs +
outer_inp_mitmot +
outer_inp_sitsot +
[inputs[0] for x in xrange(n_nit_sot)] +
self.outer_shared(inputs) +
self.outer_non_seqs(inputs))
inner_other_args = self_inputs[offset:]
inner_gfn_ins = (inner_seqs +
inner_mit_mot +
scan_sitsot_ins +
old_scan_shared_ins +
inner_other_args)
inner_gfn_outs = (scan_mit_mot_outs +
scan_sitsot_outs +
scan_nit_sot_outs +
old_scan_shared_outs)
inner_gfn_ins = (inner_inp_seqs +
inner_inp_mitmot +
inner_inp_sitsot +
self.inner_shared(self_inputs) +
self.inner_non_seqs(self_inputs))
inner_gfn_outs = (inner_out_mitmot +
inner_out_sitsot +
inner_out_nitsot)
local_op = Scan(inner_gfn_ins, inner_gfn_outs, info)
outputs = local_op(*scan_inputs)
outputs = local_op(*outer_inputs)
if type(outputs) not in (list, tuple):
outputs = [outputs]
# Re-order the gradients correctly
gradients = [grad_undefined(self, 0, args[0], 'Number of steps')]
gradients = [grad_undefined(self, 0, inputs[0], 'Number of steps')]
offset = (self.n_mit_mot +
self.n_mit_sot +
@@ -1576,12 +1500,12 @@ class Scan(PureOp):
gradients += [x[::-1] for x in outputs[:end]]
start = len(gradients)
gradients += [
grad_undefined(self, x + start, args[x + start],
grad_undefined(self, x + start, inputs[x + start],
'Shared Variable with update')
for x in xrange(self.n_shared_outs)]
start = len(gradients)
gradients += [
grad_undefined(self, x + start, args[x + start],
grad_undefined(self, x + start, inputs[x + start],
'Dimension of memory buffer for output')
for x in xrange(self.n_nit_sot)]
begin = end