提交 e9c9922b authored 作者: ChienliMa's avatar ChienliMa

Recover the deleted test_fill_diagonal()

上级 c195a8d7
...@@ -408,6 +408,62 @@ class TestBartlett(utt.InferShapeTester): ...@@ -408,6 +408,62 @@ class TestBartlett(utt.InferShapeTester):
self._compile_and_check([x], [self.op(x)], [0], self.op_class) self._compile_and_check([x], [self.op(x)], [0], self.op_class)
self._compile_and_check([x], [self.op(x)], [1], self.op_class) self._compile_and_check([x], [self.op(x)], [1], self.op_class)
class TestFillDiagonal(utt.InferShapeTester):
rng = numpy.random.RandomState(43)
def setUp(self):
super(TestFillDiagonal, self).setUp()
self.op_class = FillDiagonal
self.op = fill_diagonal
def test_perform(self):
x = tensor.matrix()
y = tensor.scalar()
f = function([x, y], fill_diagonal(x, y))
for shp in [(8, 8), (5, 8), (8, 5)]:
a = numpy.random.rand(*shp).astype(config.floatX)
val = numpy.cast[config.floatX](numpy.random.rand())
out = f(a, val)
# We can't use numpy.fill_diagonal as it is bugged.
assert numpy.allclose(numpy.diag(out), val)
assert (out == val).sum() == min(a.shape)
# test for 3d tensor
a = numpy.random.rand(3, 3, 3).astype(config.floatX)
x = tensor.tensor3()
y = tensor.scalar()
f = function([x, y], fill_diagonal(x, y))
val = numpy.cast[config.floatX](numpy.random.rand() + 10)
out = f(a, val)
# We can't use numpy.fill_diagonal as it is bugged.
assert out[0, 0, 0] == val
assert out[1, 1, 1] == val
assert out[2, 2, 2] == val
assert (out == val).sum() == min(a.shape)
def test_gradient(self):
utt.verify_grad(fill_diagonal, [numpy.random.rand(5, 8),
numpy.random.rand()],
n_tests=1, rng=TestFillDiagonal.rng)
utt.verify_grad(fill_diagonal, [numpy.random.rand(8, 5),
numpy.random.rand()],
n_tests=1, rng=TestFillDiagonal.rng)
def test_infer_shape(self):
z = tensor.dtensor3()
x = tensor.dmatrix()
y = tensor.dscalar()
self._compile_and_check([x, y], [self.op(x, y)],
[numpy.random.rand(8, 5),
numpy.random.rand()],
self.op_class)
self._compile_and_check([z, y], [self.op(z, y)],
#must be square when nd>2
[numpy.random.rand(8, 8, 8),
numpy.random.rand()],
self.op_class,
warn=False)
class TestFillDiagonalOffset(utt.InferShapeTester): class TestFillDiagonalOffset(utt.InferShapeTester):
...@@ -421,20 +477,17 @@ class TestFillDiagonalOffset(utt.InferShapeTester): ...@@ -421,20 +477,17 @@ class TestFillDiagonalOffset(utt.InferShapeTester):
def test_perform(self): def test_perform(self):
x = tensor.matrix() x = tensor.matrix()
y = tensor.scalar() y = tensor.scalar()
z = tensor.scalar() z = tensor.iscalar()
z_in = tensor.cast( z, "int32")
test_offset = numpy.random.randint(-5,5) test_offset = numpy.random.randint(-5,5)
f = function([x, y, z_in], fill_diagonal_offset(x, y, z_in)) f = function([x, y, z], fill_diagonal_offset(x, y, z))
for shp in [(8, 8), (5, 8), (8, 5)]: for shp in [(8, 8), (5, 8), (8, 5)]:
a = numpy.random.rand(*shp).astype(config.floatX) a = numpy.random.rand(*shp).astype(config.floatX)
val = numpy.cast[config.floatX](numpy.random.rand()) val = numpy.cast[config.floatX](numpy.random.rand())
out = f(a, val, test_offset) out = f(a, val, test_offset)
# We can't use numpy.fill_diagonal as it is bugged. # We can't use numpy.fill_diagonal as it is bugged.
#pdb.set_trace()
assert numpy.allclose(numpy.diag(out, test_offset), val) assert numpy.allclose(numpy.diag(out, test_offset), val)
#pdb.set_trace()
if test_offset >= 0: if test_offset >= 0:
assert (out == val).sum() == min( min(a.shape), assert (out == val).sum() == min( min(a.shape),
a.shape[1]-test_offset ) a.shape[1]-test_offset )
...@@ -458,9 +511,8 @@ class TestFillDiagonalOffset(utt.InferShapeTester): ...@@ -458,9 +511,8 @@ class TestFillDiagonalOffset(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
x = tensor.dmatrix() x = tensor.dmatrix()
y = tensor.dscalar() y = tensor.dscalar()
z = tensor.dscalar() z = tensor.iscalar()
z_in = tensor.cast( z, "int32") self._compile_and_check([x, y, z], [self.op(x, y, z)],
self._compile_and_check([x, y, z_in], [self.op(x, y, z_in)],
[numpy.random.rand(8, 5), [numpy.random.rand(8, 5),
numpy.random.rand(), numpy.random.rand(),
numpy.random.randint(-5,5)], numpy.random.randint(-5,5)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论