提交 5555bcea，作者：Pascal Lamblin

Add tests for MakeVector gradient. Find and fix bug for grad of int variables.

上级 810415ed
...@@ -408,7 +408,18 @@ class MakeVector(T.Op): ...@@ -408,7 +408,18 @@ class MakeVector(T.Op):
out[0][...] = inputs out[0][...] = inputs
def grad(self, inputs, output_gradients):
    """Return one gradient entry per scalar input of the vector.

    :param inputs: the scalar variables packed into the output vector;
        each one exposes a ``dtype`` string.
    :param output_gradients: list whose first element is the gradient
        with respect to the output vector, indexable by input position.
    :return: a list with the same length as ``inputs``.  An entry is
        ``None`` when no gradient flows to that input, otherwise it is
        the matching element of ``output_gradients[0]``.
    """
    # An output whose dtype name contains 'int' lets no gradient through,
    # whatever the dtypes of the inputs are.
    if 'int' in self.dtype:
        return [None] * len(inputs)

    # Inputs whose dtype name contains 'int' get no gradient either; every
    # other input receives the element of the output gradient at its own
    # position in the vector.
    return [None if 'int' in scalar.dtype else output_gradients[0][pos]
            for pos, scalar in enumerate(inputs)]
make_vector = MakeVector() make_vector = MakeVector()
......
...@@ -2611,14 +2611,68 @@ def test_make_vector(): ...@@ -2611,14 +2611,68 @@ def test_make_vector():
i = T.iscalar() i = T.iscalar()
d = T.dscalar() d = T.dscalar()
assert opt.MakeVector(dtype="int8")(b,b).dtype=="int8" #TODO: draw random values instead. Not really important.
assert opt.MakeVector(dtype="int32")(i,b).dtype=="int32" val = {b: 2,
assert opt.MakeVector(dtype="int32")(b,i).dtype=="int32" i: -3,
assert opt.MakeVector(dtype="float64")(b,i).dtype=="float64" d: 0.7}
assert opt.MakeVector(dtype="float64")(b,d).dtype=="float64"
assert opt.MakeVector(dtype="float64")(d,i).dtype=="float64" # Should work
assert opt.MakeVector(dtype="float64")().dtype=="float64" for (dtype, inputs) in [("int8", (b,b)),
assert opt.MakeVector(dtype="int64")().dtype=="int64" ("int32", (i,b)),
("int32", (b,i)),
("float64", (b,i)),
("float64", (b,d)),
("float64", (d,i)),
("float64", ()),
("int64", ()),
]:
mv = opt.MakeVector(dtype=dtype)(*inputs)
assert mv.dtype == dtype
f = theano.function([b,i,d], mv)
f_val = f(val[b], val[i], val[d])
print 'f_val =', f_val
s = mv.sum()
gb = T.grad(s, b)
gi = T.grad(s, i)
gd = T.grad(s, d)
print 'gb =', gb
print 'gi =', gi
print 'gd =', gd
g = theano.function([b,i,d], [gb, gi, gd])
g_val = g(val[b], val[i], val[d])
print 'g_val =', g_val
if dtype.startswith('int'):
# The gradient should be 0
assert numpy.allclose(g_val, 0)
else:
for var, grval in zip((b,i,d), g_val):
float_inputs = []
if var.dtype.startswith('int'):
assert grval == 0
elif var not in inputs:
assert grval == 0
else:
float_inputs.append(var)
# Build a function that takes float_inputs, use fix values for the
# other inputs, and returns the MakeVector. Use it for verify_grad.
if float_inputs:
def fun(*fl_inputs):
    """Build the MakeVector under test from the arguments of verify_grad.

    :param fl_inputs: the variables handed in by ``utt.verify_grad``.
        NOTE(review): assumed to be one symbolic variable per entry of
        ``float_inputs``, in the same order -- confirm against
        ``utt.verify_grad``.
    :return: ``opt.MakeVector(dtype=dtype)`` applied to every element of
        ``inputs``, where float inputs are replaced by the matching
        argument and all other inputs by their fixed value from ``val``.
    """
    # Pair each original float variable with the argument that stands in
    # for it, so the gradient is taken with respect to that argument.
    substitutes = dict(zip(float_inputs, fl_inputs))

    f_inputs = []
    # Iterate over the inputs of the vector being tested.  The previous
    # code looped over ``f_inputs`` itself, which is empty at this point,
    # so the vector was always built without any input.
    for var in inputs:
        if var in substitutes:
            # use symbolic variable
            f_inputs.append(substitutes[var])
        else:
            # use constant value
            f_inputs.append(val[var])
    return opt.MakeVector(dtype=dtype)(*f_inputs)
utt.verify_grad(fun, [val[ri] for ri in float_inputs])
#should fail #should fail
for (dtype,inputs) in [("int8",(b,i)), for (dtype,inputs) in [("int8",(b,i)),
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论