提交 83570378 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

Correction of infer_shape tests

上级 11751066
...@@ -736,20 +736,35 @@ class test_Unique(utt.InferShapeTester): ...@@ -736,20 +736,35 @@ class test_Unique(utt.InferShapeTester):
Testing the infer_shape with a vector. Testing the infer_shape with a vector.
""" """
x = theano.tensor.vector() x = theano.tensor.vector()
for op in self.ops :
for op in self.ops:
if not op.return_inverse:
continue
if op.return_index :
f = op(x)[2]
else:
f = op(x)[1]
self._compile_and_check([x], self._compile_and_check([x],
[op(x)], [f],
[np.asarray(np.array([2,1,3,2]), [np.asarray(np.array([2,1,3,2]),
dtype=config.floatX)], dtype=config.floatX)],
self.op_class) self.op_class)
def test_infer_shape_matrix(self): def test_infer_shape_matrix(self):
""" """
Testing the infer_shape with a vector. Testing the infer_shape with a matrix.
""" """
x = theano.tensor.matrix() x = theano.tensor.matrix()
for op in self.ops:
if not op.return_inverse:
continue
if op.return_index :
f = op(x)[2]
else:
f = op(x)[1]
self._compile_and_check([x], self._compile_and_check([x],
[self.op(x)], [f],
[np.asarray(np.array([[2, 1], [3, 2],[2, 3]]), [np.asarray(np.array([[2, 1], [3, 2],[2, 3]]),
dtype=config.floatX)], dtype=config.floatX)],
self.op_class) self.op_class)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论