提交 3a85ea97 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new implementation of grad method

上级 25d8089a
...@@ -34,7 +34,7 @@ from theano.gradient import DisconnectedType ...@@ -34,7 +34,7 @@ from theano.gradient import DisconnectedType
from theano.compile.profiling import ScanProfileStats from theano.compile.profiling import ScanProfileStats
import scan_utils import scan_utils
from scan_utils import safe_new from scan_utils import safe_new, forced_replace
# Logging function for sending warning or info # Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_op') _logger = logging.getLogger('theano.scan_module.scan_op')
...@@ -1194,198 +1194,138 @@ class Scan(PureOp): ...@@ -1194,198 +1194,138 @@ class Scan(PureOp):
for o, x in izip(node.outputs, scan_outs)] for o, x in izip(node.outputs, scan_outs)]
return scan_outs return scan_outs
### GRAD FUNCTION def get_input_pos(self, output_index):
def grad(self, args, g_outs): ipos = self.n_seqs
opos = output_index
# This discards information about whether incoming gradients are 0 for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
# or disconnected from the cost if len(otaps) > opos:
# TODO: upgrade scan op to report disconnection correctly return ipos
def strip_disconnected(g): else:
if isinstance(g.type, DisconnectedType): opos = opos - len(otaps)
return None ipos += len(itaps)
return g for dx, taps in enumerate(self.mitsot_taps()):
if opos == 0:
g_outs = [strip_disconnected(g) for g in g_outs] return ipos
else:
# 1. forward pass - get the outputs after applying scan opos = opos - 1
scan_outputs = self(*args) ipos += len(taps)
# 2. make sure they are given as a list if opos < self.info['n_sit_sot']:
if not(type(scan_outputs) in (list, tuple)): return ipos + opos
scan_outputs = [scan_outputs] else:
# 3. un-group / unzip the inputs return -1
# Note ! We don't want to use the actual same variable as the ones
# used by the original scan, rather create clones of them
def get_output_pos(self, input_index):
ipos = input_index
opos = 0
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
if len(itaps) > ipos:
return opos
else:
opos += len(otaps)
ipos -= len(itaps)
for dx, taps in enumerate(self.mitsot_taps()):
if len(taps) > ipos:
return opos
else:
opos += 1
ipos -= len(taps)
if ipos < self.info['n_sit_sot']:
return ipos + opos
else:
return -1
def get_output_slice_idx(self, output_index):
ipos = 0
opos = output_index
for otaps in zip(self.mitmot_out_taps()):
if len(otaps) > 0:
return ipos
else:
opos = opos - 1
ipos += len(otaps)
return ipos + opos
### GRAD FUNCTION
def grad(self, inputs, dC_douts):
outs = self(*inputs)
if not isinstance(outs, (list, tuple)):
outs = [outs]
rval = scan_utils.reconstruct_graph(self.inputs, rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs, '_grad') self.outputs)
self_inputs = rval[0] self_inputs = rval[0]
self_outputs = rval[1] self_outputs = rval[1]
#differentiable inputs
diff_inputs = (self.inner_seqs(self_inputs) +
self.inner_mitmot(self_inputs) +
self.inner_mitsot(self_inputs) +
self.inner_sitsot(self_inputs) +
self.inner_non_seqs(self_inputs))
diff_outputs = (self.inner_mitmot_outs(self_outputs) +
self.inner_mitsot_outs(self_outputs) +
self.inner_sitsot_outs(self_outputs) +
self.inner_nitsot_outs(self_outputs))
seqs = self_inputs[:self.n_seqs]
offset = self.n_seqs
n_ins_mit_mot = numpy.sum([0] + [len(self.tap_array[x]) for x
in xrange(self.n_mit_mot)])
outs_mit_mot = self_inputs[offset:offset + n_ins_mit_mot]
offset += n_ins_mit_mot
n_ins_mit_sot = numpy.sum([0] + [len(self.tap_array[x]) for x
in xrange(self.n_mit_mot,
self.n_mit_mot + self.n_mit_sot)])
outs_mit_sot = self_inputs[offset:offset + n_ins_mit_sot]
offset += n_ins_mit_sot
outs_sit_sot = self_inputs[offset:offset + self.n_sit_sot]
offset += self.n_sit_sot
old_scan_shared_ins = self_inputs[offset:offset + self.n_shared_outs]
out_offset = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_nit_sot +
self.n_sit_sot)
# shared variables as well as the condition
old_scan_shared_outs = self_outputs[out_offset:]
arg_offset = (1 +
self.n_seqs +
self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot)
old_scan_init = args[arg_offset: arg_offset + self.n_shared_outs]
offset += self.n_shared_outs
other_args = self_inputs[offset:]
# 4. Collect (possibly) differentiable inputs
diff_inputs = (seqs +
outs_mit_mot +
outs_mit_sot +
outs_sit_sot +
other_args)
#args[-len(other_args):] )
# 5. construct the function that computes the gradient (we sum over
# the gradients with respect to all outputs)
def compute_gradient(y, g_y): def compute_gradient(y, g_y):
gmp = gradient.grad_sources_inputs( gmp = gradient.grad_sources_inputs(
[(y, g_y)], diff_inputs) [(y, g_y)], diff_inputs)
return [gmp.get(p, None) for p in diff_inputs] return [gmp.get(p, None) for p in diff_inputs]
dXt_inps = [None for inp in diff_inputs]
# 6. clean the outputs (i.e. remove update rules) dXtp1_dXts = []
end = (self.n_mit_mot_outs + Xts = []
self.n_mit_sot + for idx, Xt in enumerate(diff_outputs):
self.n_sit_sot + # We are looking for x[t-1] for a given x[t]
self.n_nit_sot) if idx >= self.n_mit_mot_outs:
clean_outputs = self_outputs[:end] Xt_placeholder = Xt.type()
g_outs_no_shared = g_outs[:end] Xts.append(Xt_placeholder)
# 7.1. empty lists to hold gradients Xtm1_pos = self.get_input_pos(idx)
# List of slices from outputs (used to compute the gradients) if Xtm1_pos >= 0:
inner_g_outs = [] Xtm1 = self_inputs[Xtm1_pos]
g_out_slices = [] # It is possible that X[t] is not actually a function of
# List of outputs of the gradient function # x[t-1], case in which we can not rely on this information
inner_gfn_outs = [] try:
# slices of the input tmp = tensor.grad(Xt.sum(), Xtm1)
prev_inner_gfn_outs = [] except ValueError:
zeros_like_diff_ins = [] tmp = Xt
pos = (self.n_seqs + dXtp1_dXt = safe_new(tmp)
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot)
offset = len(args) - len(other_args) - pos
# 7.2. generate variables to represent previous steps of g_outs
for idx, diff_in in enumerate(diff_inputs):
prev_gfn_out = safe_new(diff_in)
if hasattr(diff_in, 'name') and diff_in.name:
prev_gfn_out.name = 'g_prev_' + diff_in.name
else:
prev_gfn_out.name = 'g_prev_' + str(idx)
prev_inner_gfn_outs.append(prev_gfn_out)
if idx < pos:
zeros_like_diff_ins.append(tensor.zeros_like(diff_in))
else:
zeros_like_diff_ins.append(
tensor.zeros_like(args[idx + offset]))
# 7.3. compute gradients of the inputs given one output
for dx, out in enumerate(clean_outputs):
if g_outs[dx] != None:
inner_g_out = safe_new(g_outs[dx][0])
else:
# We do not have a gradient on this output so we need a
# placeholder, which for now has the same dtype as the
# output
inner_g_out = safe_new(out)
###
#### I need to clip the gradient HERE !!
if g_outs_no_shared[dx]:
g_out_slices.append(g_outs_no_shared[dx][0])
else: else:
g_out_slices.append(None) if isinstance(dC_douts[idx].type, DisconnectedType):
if getattr(out, 'name', None) is not None: continue
inner_g_out.name = 'g_' + out.name dXtp1_dXt = safe_new(dC_douts[idx][0])
dXtp1_dXts.append(dXtp1_dXt)
_dXt_inps = compute_gradient(Xt, dXtp1_dXt)
for jdx in xrange(len(_dXt_inps)):
if dXt_inps[jdx] is None:
dXt_inps[jdx] = _dXt_inps[jdx]
elif _dXt_inps[jdx]:
dXt_inps[jdx] += _dXt_inps[jdx]
# mask inputs that get no gradients
for dx in xrange(len(dXt_inps)):
if not dXt_inps[dx]:
dXt_inps[dx] = tensor.zeros_like(diff_inputs[dx])
else: else:
inner_g_out.name = 'g_' + str(dx) for Xt, Xt_placeholder in zip(
inner_g_outs.append(inner_g_out) diff_outputs[self.n_mit_mot_outs:],
_g_out = inner_g_out Xts):
grad_outs = compute_gradient(out, _g_out) tmp = forced_replace(
if not inner_gfn_outs: dXt_inps[dx],
for idx, gfn_out in enumerate(grad_outs): Xt,
if idx >= self.n_seqs: Xt_placeholder)
inner_gfn_outs.append(prev_inner_gfn_outs[idx]) dXt_inps[dx] = tmp
else:
inner_gfn_outs.append(None) # construct dX_dtm1
# 7.4 Sum the gradients dXt_dXtm1s = [x.type() for x in dXt_inps[self.n_seqs:]]
# safety check, some of this inputs might still not be for dx, dXt_dXtm1 in enumerate(dXt_dXtm1s):
# differentiable, for those we don't add them to the mix dXt_inps[dx+self.n_seqs] += dXt_dXtm1
# (assume their gradient is 0) # Construct scan op
for i, (x, y) in enumerate(zip(grad_outs, inner_gfn_outs)): # Seqs
if x and y: outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
inner_gfn_outs[i] = x + y
elif y:
inner_gfn_outs[i] = y
else:
inner_gfn_outs[i] = x
## 8. Mask the outputs that are not differentiable
# backwards pass
for i in xrange(len(inner_gfn_outs)):
if inner_gfn_outs[i] is None:
inner_gfn_outs[i] = tensor.zeros_like(diff_inputs[i])
## 9. Mask the g_outs that are Nones :
for i, out in enumerate(scan_outputs):
if g_outs[i] is None:
try:
# this try is for catching non ndarray inputs (random
# states) it is more of a safety check ( all random
# states should be after n_outs_not_shared ...
g_outs[i] = tensor.zeros_like(scan_outputs[i])
except Exception:
g_outs[i] = theano.tensor.constant(
numpy.array(0, theano.config.floatX))
## 10. Get your sequence in order for the scan:
n_seqs = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot +
self.n_nit_sot)
offset = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot)
inner_seqs = (seqs +
outs_mit_mot +
outs_mit_sot +
outs_sit_sot +
inner_g_outs[offset:offset + self.n_nit_sot])
scan_seqs = [x[::-1] for x in args[1:self.n_seqs + 1]]
offset = 0
for idx in xrange(self.n_mit_mot + self.n_mit_sot): for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = numpy.min(self.tap_array[idx]) mintap = numpy.min(self.tap_array[idx])
maxtap = numpy.max(self.tap_array[idx]) maxtap = numpy.max(self.tap_array[idx])
seq = scan_outputs[offset + idx] seq = outs[idx]
for k in self.tap_array[idx]: for k in self.tap_array[idx]:
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
if maxtap < 0: if maxtap < 0:
dim_offset = abs(maxtap) dim_offset = abs(maxtap)
else: else:
...@@ -1397,124 +1337,125 @@ class Scan(PureOp): ...@@ -1397,124 +1337,125 @@ class Scan(PureOp):
-(maxtap - k + 1)][::-1] -(maxtap - k + 1)][::-1]
else: else:
nw_seq = seq[dim_offset + k - mintap - 1: -1][::-1] nw_seq = seq[dim_offset + k - mintap - 1: -1][::-1]
if getattr(seq, 'name', None) is not None: outer_inp_seqs.append(nw_seq)
nw_seq.name = seq.name + '[%d:]' % k outer_inp_seqs += [
scan_seqs.append(nw_seq) x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
for x in self.outer_nitsot_outs(dC_douts):
offset += self.n_mit_sot if not isinstance(x.type, DisconnectedType):
for idx in xrange(self.n_sit_sot): outer_inp_seqs.append(x[::-1])
seq = scan_outputs[offset + idx][:-1]
scan_seqs.append(seq[::-1]) outer_inp_seqs += [x[::-1] for x in self.outer_mitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_sitsot_outs(outs)]
offset = (self.n_mit_mot_outs + outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
self.n_mit_sot +
self.n_sit_sot) inner_inp_seqs = self.inner_seqs(self_inputs)
scan_seqs += [x[::-1] for x in inner_inp_seqs += self.inner_mitmot(self_inputs)
g_outs[offset:offset + self.n_nit_sot]] inner_inp_seqs += self.inner_mitsot(self_inputs)
inner_inp_seqs += self.inner_sitsot(self_inputs)
scan_mit_mot = [] inner_inp_seqs += self.inner_nitsot_outs(dXtp1_dXts)
inner_mit_mot = [] inner_inp_seqs += Xts
scan_mit_mot_outs = [] # mitmot
mit_mot_taps = [] outer_inp_mitmot = []
mit_mot_out_slices = [] outer_out_mitmot = []
inner_inp_mitmot = []
inner_out_mitmot = []
mitmot_inp_taps = []
mitmot_out_taps = []
out_pos = 0 out_pos = 0
ins_pos = n_seqs
n_mit_mot_outs = 0
n_mit_mot_ins = 0
ins_pos = self.n_seqs ins_pos = self.n_seqs
n_mitmot_outs = 0
n_mitmot_inps = 0
for idx in xrange(self.n_mit_mot): for idx in xrange(self.n_mit_mot):
scan_mit_mot.append(g_outs[idx][::-1]) outer_inp_mitmot.append(dC_douts[idx][::-1])
mit_mot_taps.append([]) mitmot_inp_taps.append([])
mit_mot_out_slices.append([]) mitmot_out_taps.append([])
for jdx in xrange(len(self.mit_mot_out_slices[idx])): for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_mit_mot.append(inner_g_outs[out_pos]) inner_inp_mitmot.append(dXtp1_dXts[out_pos])
mit_mot_taps[idx].append(\ mitmot_inp_taps[idx].append(-self.mit_mot_out_slices[idx][jdx])
-self.mit_mot_out_slices[idx][jdx]) n_mitmot_inps += 1
n_mit_mot_ins += 1
out_pos += 1 out_pos += 1
for jdx in xrange(len(self.tap_array[idx])): for jdx in xrange(len(self.tap_array[idx])):
inner_mit_mot.append(prev_inner_gfn_outs[ins_pos]) inner_inp_mitmot.append(dXt_dXtm1s[ins_pos - self.n_seqs])
scan_mit_mot_outs.append(\ inner_out_mitmot.append(dXt_inps[ins_pos])
inner_gfn_outs[ins_pos]) n_mitmot_inps_ += 1
n_mit_mot_ins += 1
ins_pos += 1 ins_pos += 1
n_mit_mot_outs += 1 n_mitmot_outs += 1
mit_mot_taps[idx].append(-self.tap_array[idx][jdx]) mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
mit_mot_out_slices[idx].append(\ mitmot_out_taps[idx].append(-self.tap_array[idx][jdx])
-self.tap_array[idx][jdx])
offset = self.n_mit_mot offset = self.n_mit_mot
for idx in xrange(self.n_mit_sot): for idx in xrange(self.n_mit_sot):
mit_mot_taps.append([]) mitmot_inp_taps.append([])
mit_mot_out_slices.append([]) mitmot_out_taps.append([])
scan_mit_mot.append(g_outs[idx + offset][::-1]) outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
idx_tap = idx + self.n_mit_mot idx_tap = idx + self.n_mit_mot
inner_inp_mitmot.append(dXtp1_dXts[out_pos])
out_pos += 1
n_mitmot_inps += 1
mitmot_inp_taps[idx + offset].append(0)
for jdx in xrange(len(self.tap_array[idx_tap])): for jdx in xrange(len(self.tap_array[idx_tap])):
inner_mit_mot.append(prev_inner_gfn_outs[ins_pos]) inner_inp_mitmot.append(dXt_dXtm1s[ins_pos - self.n_seqs])
mit_mot_taps[idx + offset].append(\ mitmot_inp_taps[idx + offset].append(
-self.tap_array[idx_tap][jdx]) -self.tap_array[idx_tap][jdx])
mit_mot_out_slices[idx].append(\ mitmot_out_taps[idx].append(
-self.tap_array[idx_tap][jdx]) -self.tap_array[idx_tap][jdx])
scan_mit_mot_outs.append(inner_gfn_outs[ins_pos]) inner_out_mitmot.append(dXt_inps[ins_pos])
n_mit_mot_ins += 1 n_mitmot_inps += 1
ins_pos += 1 ins_pos += 1
n_mit_mot_outs += 1 n_mitmot_outs += 1
inner_mit_mot.append(inner_g_outs[out_pos])
out_pos += 1
n_mit_mot_ins += 1
mit_mot_taps[idx + offset].append(0)
offset += self.n_mit_sot offset += self.n_mit_sot
for idx in xrange(self.n_sit_sot): for idx in xrange(self.n_sit_sot):
mit_mot_taps.append([0, 1]) mitmot_inp_taps.append([0, 1])
mit_mot_out_slices.append([1]) mitmot_out_taps.append([1])
scan_mit_mot.append(g_outs[idx + offset][::-1]) if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
scan_mit_mot_outs.append(inner_gfn_outs[ins_pos]) outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
inner_mit_mot += [inner_g_outs[out_pos], else:
prev_inner_gfn_outs[ins_pos]] outer_inp_mitmot.append(
n_mit_mot_outs += 1 tensor.zeros(outs[idx + offset].shape,
dtype = dXt_inps[ins_pos].dtype))
inner_out_mitmot.append(dXt_inps[ins_pos])
inner_inp_mitmot += [dXtp1_dXts[out_pos],
dXt_dXtm1s[ins_pos - self.n_seqs]]
n_mitmot_outs += 1
out_pos += 1 out_pos += 1
ins_pos += 1 ins_pos += 1
n_mit_mot_ins += 2 n_mitmot_inps += 2
n_nit_sot = self.n_seqs
scan_nit_sot_outs = inner_gfn_outs[:self.n_seqs]
if self.truncate_gradient != -1: if self.truncate_gradient != -1:
do_steps = tensor.minimum(args[0], self.truncate_gradient) do_steps = tensor.minimum(inputs[0], self.truncate_gradient)
else: else:
do_steps = args[0] do_steps = inputs[0]
offset = (self.n_seqs +
n_ins_mit_sot + n_nit_sot = self.n_seqs
n_ins_mit_mot + inner_out_nitsot = dXt_inps[:self.n_seqs]
self.n_sit_sot) inner_out_sitsot = dXt_inps[ins_pos:]
# Instead of shared outs use sit_sot inner_inp_sitsot = dXt_dXtm1s[ins_pos - self.n_seqs:]
n_sitsot_outs = len(prev_inner_gfn_outs[offset:]) outer_inp_sitsot = [
scan_sitsot_ins = prev_inner_gfn_outs[offset:] tensor.zeros([do_steps + 1] +
scan_sitsot_init = [] [x.shape[i] for i in xrange(x.ndim)],
for x in zeros_like_diff_ins[offset:]: dtype = y.dtype)
shapes = [x.shape[i] for i in xrange(x.ndim)] for y, x in zip(inner_inp_sitsot,
empty = tensor.zeros([do_steps + 1] + shapes, self.outer_non_seqs(inputs))]
dtype=x.dtype)
scan_sitsot_init.append(empty) n_sitsot_outs = len(outer_inp_sitsot)
scan_sitsot_outs = inner_gfn_outs[offset:] new_tap_array = mitmot_inp_taps + [[-1] for k in
tap_array = mit_mot_taps + [[-1] for k in
xrange(n_sitsot_outs)] xrange(n_sitsot_outs)]
info = {} info = {}
info['n_seqs'] = n_seqs info['n_seqs'] = len(outer_inp_seqs)
info['n_mit_sot'] = 0 info['n_mit_sot'] = 0
info['tap_array'] = tap_array info['tap_array'] = new_tap_array
info['gpu'] = False info['gpu'] = False
n_mit_mot = (self.n_mit_mot + info['n_mit_mot'] = len(outer_inp_mitmot)
self.n_mit_sot + info['n_mit_mot_outs'] = n_mitmot_outs
self.n_sit_sot) info['mit_mot_out_slices'] = mitmot_out_taps
info['n_mit_mot'] = n_mit_mot
info['n_mit_mot_outs'] = n_mit_mot_outs
info['mit_mot_out_slices'] = mit_mot_out_slices
info['truncate_gradient'] = self.truncate_gradient info['truncate_gradient'] = self.truncate_gradient
info['n_sit_sot'] = n_sitsot_outs info['n_sit_sot'] = n_sitsot_outs
info['n_shared_outs'] = self.n_shared_outs info['n_shared_outs'] = 0
info['n_nit_sot'] = n_nit_sot info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while info['as_while'] = self.as_while
info['profile'] = self.profile info['profile'] = self.profile
...@@ -1524,47 +1465,30 @@ class Scan(PureOp): ...@@ -1524,47 +1465,30 @@ class Scan(PureOp):
else: else:
info['name'] = None info['name'] = None
info['mode'] = self.mode info['mode'] = self.mode
n_mit_sot = 0
n_sit_sot = 0
offset = (1 +
self.n_seqs +
self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot +
self.n_shared_outs)
scan_inputs = ([do_steps] + outer_inputs = ([do_steps] +
scan_seqs + outer_inp_seqs +
scan_mit_mot + outer_inp_mitmot +
scan_sitsot_init + outer_inp_sitsot +
old_scan_init + [inputs[0] for x in xrange(n_nit_sot)] +
[args[0] for x in xrange(n_nit_sot)] + self.outer_shared(inputs) +
args[offset:]) self.outer_non_seqs(inputs))
offset = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot +
self.n_shared_outs)
inner_other_args = self_inputs[offset:] inner_other_args = self_inputs[offset:]
inner_gfn_ins = (inner_seqs + inner_gfn_ins = (inner_inp_seqs +
inner_mit_mot + inner_inp_mitmot +
scan_sitsot_ins + inner_inp_sitsot +
old_scan_shared_ins + self.inner_shared(self_inputs) +
inner_other_args) self.inner_non_seqs(self_inputs))
inner_gfn_outs = (scan_mit_mot_outs + inner_gfn_outs = (inner_out_mitmot +
scan_sitsot_outs + inner_out_sitsot +
scan_nit_sot_outs + inner_out_nitsot)
old_scan_shared_outs)
local_op = Scan(inner_gfn_ins, inner_gfn_outs, info) local_op = Scan(inner_gfn_ins, inner_gfn_outs, info)
outputs = local_op(*scan_inputs) outputs = local_op(*outer_inputs)
if type(outputs) not in (list, tuple): if type(outputs) not in (list, tuple):
outputs = [outputs] outputs = [outputs]
# Re-order the gradients correctly # Re-order the gradients correctly
gradients = [grad_undefined(self, 0, args[0], 'Number of steps')] gradients = [grad_undefined(self, 0, inputs[0], 'Number of steps')]
offset = (self.n_mit_mot + offset = (self.n_mit_mot +
self.n_mit_sot + self.n_mit_sot +
...@@ -1576,12 +1500,12 @@ class Scan(PureOp): ...@@ -1576,12 +1500,12 @@ class Scan(PureOp):
gradients += [x[::-1] for x in outputs[:end]] gradients += [x[::-1] for x in outputs[:end]]
start = len(gradients) start = len(gradients)
gradients += [ gradients += [
grad_undefined(self, x + start, args[x + start], grad_undefined(self, x + start, inputs[x + start],
'Shared Variable with update') 'Shared Variable with update')
for x in xrange(self.n_shared_outs)] for x in xrange(self.n_shared_outs)]
start = len(gradients) start = len(gradients)
gradients += [ gradients += [
grad_undefined(self, x + start, args[x + start], grad_undefined(self, x + start, inputs[x + start],
'Dimension of memory buffer for output') 'Dimension of memory buffer for output')
for x in xrange(self.n_nit_sot)] for x in xrange(self.n_nit_sot)]
begin = end begin = end
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论