提交 3603fce6 authored 作者: Frederic's avatar Frederic

Add tests

上级 512b3eac
...@@ -5016,6 +5016,57 @@ class TestShape_i(utt.InferShapeTester): ...@@ -5016,6 +5016,57 @@ class TestShape_i(utt.InferShapeTester):
[admat_val], Shape_i) [admat_val], Shape_i)
class TestShapeFeature(unittest.TestCase):
def test_scalar(self):
x = scalar()
cst = T.constant(1).clone()
o = x + cst
fgraph = FunctionGraph([x], [o], clone=False)
shape_feature = opt.ShapeFeature()
fgraph.attach_feature(shape_feature)
assert shape_feature.same_shape(x, o)
def test_vector(self):
x = vector()
cst = T.constant(1).clone()
o = x + cst
fgraph = FunctionGraph([x], [o], clone=False)
shape_feature = opt.ShapeFeature()
fgraph.attach_feature(shape_feature)
assert shape_feature.same_shape(x, o)
def test_vector2(self):
x = vector()
y = vector()
o = x + y
fgraph = FunctionGraph([x, y], [o], clone=False)
shape_feature = opt.ShapeFeature()
fgraph.attach_feature(shape_feature)
assert shape_feature.same_shape(x, o)
# The following case isn't implemented
assert not shape_feature.same_shape(y, o)
def test_vector_dim(self):
x = vector()
y = vector()
o = x + y
fgraph = FunctionGraph([x, y], [o], clone=False)
shape_feature = opt.ShapeFeature()
fgraph.attach_feature(shape_feature)
assert shape_feature.same_shape(x, o, 0, 0)
# The following case isn't implemented
assert not shape_feature.same_shape(y, o, 0, 0)
def test_vector_dim_err(self):
x = vector()
y = vector()
o = x + y
fgraph = FunctionGraph([x, y], [o], clone=False)
shape_feature = opt.ShapeFeature()
fgraph.attach_feature(shape_feature)
self.assertRaises(IndexError, shape_feature.same_shape, x, o, 1, 0)
self.assertRaises(IndexError, shape_feature.same_shape, x, o, 0, 1)
if __name__ == '__main__': if __name__ == '__main__':
t = TestMakeVector('setUp') t = TestMakeVector('setUp')
t.setUp() t.setUp()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论