Commit 4ecb5a04 authored by lamblin

Merge pull request #995 from pascanur/fix_grad_scan_dtype

Fix grad scan dtype
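Before this fix, the gradient of a scan whose recurrent states mix dtypes (for example an int32 condition state carried alongside float32 states) could build inner gradient placeholders with the wrong dtype and fail with the upcast/downcast error raised in the first hunk below. A minimal sketch of the failure mode, adapted from the test_grad_dtype_change test added near the end of this diff (it relies on theano.scan's defaults of truncate_gradient=-1 and go_backwards=False):

import numpy
import theano
import theano.tensor as tensor

x = tensor.fscalar('x')
y = tensor.fscalar('y')
c = tensor.iscalar('c')

def inner_fn(cond, x, y):
    # One recurrent state is int32, the other two are float32.
    new_cond = tensor.cast(tensor.switch(cond, x, y), 'int32')
    new_x = tensor.switch(cond, tensor.nnet.sigmoid(y * x), x)
    new_y = tensor.switch(cond, y, tensor.nnet.sigmoid(x))
    return new_cond, new_x, new_y

values, _ = theano.scan(inner_fn, outputs_info=[c, x, y], n_steps=10)
# Differentiating through a float32 state used to fail because the
# placeholder gradient for the int32 state kept the wrong dtype.
gX, gY = tensor.grad(values[1].sum(), [x, y])
f = theano.function([c, x, y], [gX, gY], allow_input_downcast=True)
f(numpy.int32(0), numpy.float32(1.), numpy.float32(.5))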
@@ -235,13 +235,13 @@ class Scan(PureOp):
                             'graph of scan results in an upcast or downcast. '
                             'Please make sure that you use dtypes consistently')
         # TODO make the assert exact
-        # TODO assert the type(dtype, nbdim of self.inputs and
+        # TODO assert the type(dtype, ndim of self.inputs and
         #      inputs correspond)
-        #assert len(inputs) >= len(self.inputs)
-        #if self.info['as_while']:
+        # assert len(inputs) >= len(self.inputs)
+        # if self.info['as_while']:
         #     assert len(inputs) == len(self.inputs) + 2 + \
         #            self.info["n_nit_sot"]
-        #else:
+        # else:
         #     assert len(inputs) == len(self.inputs) + 1 + \
         #         self.info["n_nit_sot"]
         # Flags that indicate which inputs are vectors
@@ -292,8 +292,10 @@ class Scan(PureOp):
                          str(outer_mitmot),
                          argoffset + idx,
                          outer_mitmot.type.dtype,
+                         outer_mitmot[ipos + k].ndim,
                          str(inner_mitmot[ipos + k]),
-                         inner_mitmot[ipos + k].type.dtype))
+                         inner_mitmot[ipos + k].type.dtype,
+                         inner_mitmot[ipos + k].ndim))
             ipos += len(itaps)
             for k in xrange(len(otaps)):
                 if (inner_mitmot_outs[opos + k].type.dtype != \
@@ -304,7 +306,9 @@ class Scan(PureOp):
                         (str(outer_mitmot),
                          argoffset + idx,
                          outer_mitmot.type.dtype,
-                         inner_mitmot_outs[opos + k].type.dtype))
+                         outer_mitmot.ndim,
+                         inner_mitmot_outs[opos + k].type.dtype,
+                         inner_mitmot_outs[opos + k].ndim))
             opos += len(otaps)
         argoffset += len(self.outer_mitmot(inputs))
         # Same checks as above but for outputs of type mit_sot
@@ -1194,6 +1198,37 @@ class Scan(PureOp):
                      for o, x in izip(node.outputs, scan_outs)]
         return scan_outs

+    def get_input_pos(self, output_index):
+        ipos = 0
+        opos = output_index
+        for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
+            if len(otaps) > opos:
+                return ipos
+            else:
+                opos = opos - len(otaps)
+                ipos += len(itaps)
+        for dx, taps in enumerate(self.mitsot_taps()):
+            if opos == 0:
+                return ipos
+            else:
+                opos = opos - 1
+            ipos += len(taps)
+        if opos < self.info['n_sit_sot']:
+            return ipos + opos
+        else:
+            return -1
+
+    def get_output_slice_idx(self, output_index):
+        ipos = 0
+        opos = output_index
+        for otaps in zip(self.mitmot_out_taps()):
+            if len(otaps) > 0:
+                return ipos
+            else:
+                opos = opos - 1
+            ipos += len(otaps)
+        return ipos + opos
+
     ### GRAD FUNCTION
     def grad(self, args, g_outs):
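The get_input_pos helper added in the hunk above maps a scan output index back to the position of the inner-graph input that carries its previous value, by walking the mit-mot, mit-sot and sit-sot tap lists in order. A standalone sketch of the same walk, using a hypothetical tap layout for illustration (the tap counts below are made up, not taken from any real scan):

def get_input_pos(output_index, mitmot_out_taps, mitmot_taps,
                  mitsot_taps, n_sit_sot):
    # Walk mit-mot recurrences: each produces len(otaps) outputs
    # and consumes len(itaps) inner inputs.
    ipos = 0
    opos = output_index
    for otaps, itaps in zip(mitmot_out_taps, mitmot_taps):
        if len(otaps) > opos:
            return ipos
        opos -= len(otaps)
        ipos += len(itaps)
    # Walk mit-sot recurrences: one output each, len(taps) inner inputs.
    for taps in mitsot_taps:
        if opos == 0:
            return ipos
        opos -= 1
        ipos += len(taps)
    # sit-sot recurrences: one output and one inner input each.
    if opos < n_sit_sot:
        return ipos + opos
    return -1

# One mit-mot (2 output taps, 3 input taps), one mit-sot (2 taps), two
# sit-sots; outputs 0-1 belong to the mit-mot, 2 to the mit-sot, 3-4 to
# the sit-sots, and index 6 has no matching inner input.
assert get_input_pos(0, [[1, 3]], [[-2, -1, 1]], [[-2, -1]], 2) == 0
assert get_input_pos(2, [[1, 3]], [[-2, -1, 1]], [[-2, -1]], 2) == 3
assert get_input_pos(3, [[1, 3]], [[-2, -1, 1]], [[-2, -1]], 2) == 5
assert get_input_pos(6, [[1, 3]], [[-2, -1, 1]], [[-2, -1]], 2) == -1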
@@ -1291,12 +1326,6 @@ class Scan(PureOp):
         offset = len(args) - len(other_args) - pos
         # 7.2. generate variables to represent previous steps of g_outs
         for idx, diff_in in enumerate(diff_inputs):
-            prev_gfn_out = safe_new(diff_in)
-            if hasattr(diff_in, 'name') and diff_in.name:
-                prev_gfn_out.name = 'g_prev_' + diff_in.name
-            else:
-                prev_gfn_out.name = 'g_prev_' + str(idx)
-            prev_inner_gfn_outs.append(prev_gfn_out)
             if idx < pos:
                 zeros_like_diff_ins.append(tensor.zeros_like(diff_in))
             else:
@@ -1305,12 +1334,12 @@ class Scan(PureOp):
         # 7.3. compute gradients of the inputs given one output
         for dx, out in enumerate(clean_outputs):
             if g_outs[dx] != None:
-                inner_g_out = safe_new(g_outs[dx][0])
+                input_pos = self.get_input_pos(dx)
+                if input_pos >= 0:
+                    corresponding_input = self_inputs[input_pos]
+                    tmp = tensor.grad(out.sum(), corresponding_input)
+                    inner_g_out = safe_new(tmp)
             else:
+                # We do not have a gradient on this output so we need a
+                # placeholder, which for now has the same dtype as the
+                # output
                 inner_g_out = safe_new(out)
             ###
             #### I need to clip the gradient HERE !!
@@ -1328,10 +1357,7 @@ class Scan(PureOp):
             grad_outs = compute_gradient(out, _g_out)
             if not inner_gfn_outs:
                 for idx, gfn_out in enumerate(grad_outs):
-                    if idx >= self.n_seqs:
-                        inner_gfn_outs.append(prev_inner_gfn_outs[idx])
-                    else:
-                        inner_gfn_outs.append(None)
+                    inner_gfn_outs.append(None)
         # 7.4 Sum the gradients
         # safety check, some of this inputs might still not be
         # differentiable, for those we don't add them to the mix
@@ -1344,6 +1370,10 @@ class Scan(PureOp):
             else:
                 inner_gfn_outs[i] = x
+        prev_inner_gfn_outs = [x.type() for x in inner_gfn_outs]
+        for dx in xrange(self.n_seqs, len(inner_gfn_outs)):
+            inner_gfn_outs[dx] = inner_gfn_outs[dx] + \
+                prev_inner_gfn_outs[dx]
         ## 8. Mask the outputs that are not differentiable
         # backwards pass
         for i in xrange(len(inner_gfn_outs)):
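These four added lines, together with the g_prev_* block deleted in section 7.2 above, change how the previous-step gradients are represented: instead of cloning the differentiable inputs up front, the placeholders are now derived from the gradient expressions themselves, so their dtypes always match. Calling x.type() on a Theano variable builds a fresh variable of the exact same Type; a tiny sketch of the idiom (the names here are illustrative, not from the patch):

import theano.tensor as tensor

g = tensor.fvector('g')   # stands in for one entry of inner_gfn_outs
prev = g.type()           # fresh variable with the same dtype and ndim
assert prev.dtype == g.dtype and prev.ndim == g.ndim
acc = g + prev            # the running sum keeps a stable dtype
assert acc.dtype == g.dtype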
@@ -1357,7 +1387,9 @@ class Scan(PureOp):
                 # this try is for catching non ndarray inputs (random
                 # states) it is more of a safety check ( all random
                 # states should be after n_outs_not_shared ...
-                g_outs[i] = tensor.zeros_like(scan_outputs[i])
+                g_outs[i] = tensor.cast(
+                    tensor.zeros_like(scan_outputs[i]),
+                    inner_gfn_outs[self.get_output_slice_idx(i)].dtype)
             except Exception:
                 g_outs[i] = theano.tensor.constant(
                     numpy.array(0, theano.config.floatX))
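The cast added in the hunk above matters because tensor.zeros_like preserves the dtype of its argument, so a zero gradient built directly from an integer-typed scan output would reintroduce the mismatch this PR removes. A small illustration (the variable is hypothetical):

import theano.tensor as tensor

state = tensor.ivector('state')   # an int32 scan output
z = tensor.zeros_like(state)
assert z.dtype == 'int32'         # zeros_like keeps the output's dtype
g = tensor.cast(z, 'float32')     # cast to the inner gradient's dtype
assert g.dtype == 'float32'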
@@ -1493,8 +1525,9 @@ class Scan(PureOp):
         n_sitsot_outs = len(prev_inner_gfn_outs[offset:])
         scan_sitsot_ins = prev_inner_gfn_outs[offset:]
         scan_sitsot_init = []
-        for x in zeros_like_diff_ins[offset:]:
-            shapes = [x.shape[i] for i in xrange(x.ndim)]
+        for x, y in zip(prev_inner_gfn_outs[offset:],
+                        zeros_like_diff_ins[offset:]):
+            shapes = [y.shape[i] for i in xrange(x.ndim)]
             empty = tensor.zeros([do_steps + 1] + shapes,
                                  dtype=x.dtype)
             scan_sitsot_init.append(empty)
......
@@ -513,7 +513,7 @@ class T_Scan(unittest.TestCase):
         def f_rnn(u_t, x_tm1, W_in, W):
             return (u_t * W_in + x_tm1 * W,
-                    tensor.cast(u_t+x_tm1, 'int64'))
+                    tensor.cast(u_t + x_tm1, 'int64'))

         u = theano.tensor.fvector('u')
         x0 = theano.tensor.fscalar('x0')
@@ -561,7 +561,6 @@ class T_Scan(unittest.TestCase):
         scan_node = scan_node[0]
         assert scan_node.op.gpu
-
     # simple rnn, one input, one state, weights for each; input/state
     # are vectors, weights are scalars; using shared variables
     def test_one_sequence_one_output_weights_shared(self):
@@ -1124,6 +1123,29 @@ class T_Scan(unittest.TestCase):
         assert numpy.allclose(W1.get_value(), numpy_W1)
         assert numpy.allclose(W2.get_value(), numpy_W2)

+    def test_grad_dtype_change(self):
+        x = tensor.fscalar('x')
+        y = tensor.fscalar('y')
+        c = tensor.iscalar('c')
+
+        def inner_fn(cond, x, y):
+            new_cond = tensor.cast(tensor.switch(cond, x, y), 'int32')
+            new_x = tensor.switch(cond, tensor.nnet.sigmoid(y * x), x)
+            new_y = tensor.switch(cond, y, tensor.nnet.sigmoid(x))
+            return new_cond, new_x, new_y
+
+        values, _ = theano.scan(
+            inner_fn,
+            outputs_info=[c, x, y],
+            n_steps=10,
+            truncate_gradient=-1,
+            go_backwards=False)
+        gX, gY = tensor.grad(values[1].sum(), [x, y])
+        f = theano.function([c, x, y], [gX, gY],
+                            allow_input_downcast=True)
+        # Check for runtime errors
+        f(numpy.int32(0), numpy.float32(1.), numpy.float32(.5))
+
     def test_simple_shared_mrg_random(self):
         theano_rng = theano.sandbox.rng_mrg.MRG_RandomStreams(utt.fetch_seed())
@@ -1874,8 +1896,8 @@ class T_Scan(unittest.TestCase):
     def test_scan_extra_inputs_hessian(self):
         x = theano.tensor.vector('x')
         A = theano.tensor.matrix('A')
-        fc1 = theano.shared(0.5, name = 'fc1')
-        fc2 = theano.shared(0.9, name = 'fc2')
+        fc1 = theano.shared(0.5, name='fc1')
+        fc2 = theano.shared(0.9, name='fc2')
         y = fc1 * theano.dot(x * x, theano.dot(A, x))
         y.name = 'y'
         gy = theano.tensor.grad(y, x)
@@ -3549,7 +3571,7 @@ def test_compute_test_value():
        fn=lambda u, v: u + v,
        sequences=[x, y])
    assert not _
-   z.name='z'
+   z.name = 'z'
    # The gradient computation used to crash before 6af465e.
    g = tensor.grad(z.sum(), x)
    #f = theano.function([x], g)
......