提交 5bd45a5b authored 作者: Frederic Bastien's avatar Frederic Bastien

fix test following raname of Env.

上级 99f22814
......@@ -3293,29 +3293,29 @@ class T_local_sum(unittest.TestCase):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True, True))()
env = Env([x], [x.sum()])
optimizer.optimize(env)
g = FunctionGraph([x], [x.sum()])
optimizer.optimize(g)
assert not any([
isinstance(node.op, T.CAReduce)
for node in env.toposort()])
for node in g.toposort()])
def test_local_sum_broadcast_all_1(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, True))()
env = Env([x], [x.sum(axis=[0, 1])])
optimizer.optimize(env)
g = FunctionGraph([x], [x.sum(axis=[0, 1])])
optimizer.optimize(g)
assert not any([
isinstance(node.op, T.CAReduce)
for node in env.toposort()])
for node in g.toposort()])
def test_local_sum_broadcast_some_0(self):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, False, True))()
env = Env([x], [x.sum(axis=[0, 1])])
optimizer.optimize(env)
order = env.toposort()
g = FunctionGraph([x], [x.sum(axis=[0, 1])])
optimizer.optimize(g)
order = g.toposort()
assert 1 == sum([isinstance(node.op, T.CAReduce) for node in order])
op = order[-2].op
assert isinstance(op, T.CAReduce)
......@@ -3329,9 +3329,9 @@ class T_local_sum(unittest.TestCase):
optimizer = optdb.query(self.mode._optimizer)
x = T.TensorType('int64', (True, False, True))()
env = Env([x], [x.sum(axis=[0, 2])])
optimizer.optimize(env)
order = env.toposort()
g = FunctionGraph([x], [x.sum(axis=[0, 2])])
optimizer.optimize(g)
order = g.toposort()
assert 0 == sum([isinstance(node.op, T.CAReduce) for node in order])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论