提交 bd8864ee authored 作者: nouiz's avatar nouiz

Merge pull request #760 from nouiz/fix_test

fix test following raname of Env.
...@@ -3293,29 +3293,29 @@ class T_local_sum(unittest.TestCase): ...@@ -3293,29 +3293,29 @@ class T_local_sum(unittest.TestCase):
optimizer = optdb.query(self.mode._optimizer) optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True, True))() x = T.TensorType('int64', (True, True, True))()
env = Env([x], [x.sum()]) g = FunctionGraph([x], [x.sum()])
optimizer.optimize(env) optimizer.optimize(g)
assert not any([ assert not any([
isinstance(node.op, T.CAReduce) isinstance(node.op, T.CAReduce)
for node in env.toposort()]) for node in g.toposort()])
def test_local_sum_broadcast_all_1(self): def test_local_sum_broadcast_all_1(self):
optimizer = optdb.query(self.mode._optimizer) optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True))() x = T.TensorType('int64', (True, True))()
env = Env([x], [x.sum(axis=[0, 1])]) g = FunctionGraph([x], [x.sum(axis=[0, 1])])
optimizer.optimize(env) optimizer.optimize(g)
assert not any([ assert not any([
isinstance(node.op, T.CAReduce) isinstance(node.op, T.CAReduce)
for node in env.toposort()]) for node in g.toposort()])
def test_local_sum_broadcast_some_0(self): def test_local_sum_broadcast_some_0(self):
optimizer = optdb.query(self.mode._optimizer) optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, False, True))() x = T.TensorType('int64', (True, False, True))()
env = Env([x], [x.sum(axis=[0, 1])]) g = FunctionGraph([x], [x.sum(axis=[0, 1])])
optimizer.optimize(env) optimizer.optimize(g)
order = env.toposort() order = g.toposort()
assert 1 == sum([isinstance(node.op, T.CAReduce) for node in order]) assert 1 == sum([isinstance(node.op, T.CAReduce) for node in order])
op = order[-2].op op = order[-2].op
assert isinstance(op, T.CAReduce) assert isinstance(op, T.CAReduce)
...@@ -3329,9 +3329,9 @@ class T_local_sum(unittest.TestCase): ...@@ -3329,9 +3329,9 @@ class T_local_sum(unittest.TestCase):
optimizer = optdb.query(self.mode._optimizer) optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, False, True))() x = T.TensorType('int64', (True, False, True))()
env = Env([x], [x.sum(axis=[0, 2])]) g = FunctionGraph([x], [x.sum(axis=[0, 2])])
optimizer.optimize(env) optimizer.optimize(g)
order = env.toposort() order = g.toposort()
assert 0 == sum([isinstance(node.op, T.CAReduce) for node in order]) assert 0 == sum([isinstance(node.op, T.CAReduce) for node in order])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论