提交 9a67255b authored 作者: Frederic Bastien's avatar Frederic Bastien

fix tests in FAST_COMPILE mode.

上级 a6b07409
......@@ -3317,12 +3317,16 @@ class T_local_sum(unittest.TestCase):
optimizer.optimize(g)
order = g.toposort()
assert 1 == sum([isinstance(node.op, T.CAReduce) for node in order])
op = order[-2].op
if config.mode == 'FAST_COMPILE':
node = order[-1]
else:
node = order[-2]
op = node.op
assert isinstance(op, T.CAReduce)
# -- the leading broadcastable dimension has been dropped
# by the local_sum_broadcastable optimization
# now summation is over the original x's dimension 1.
assert order[-2].inputs[0].ndim == 2, order[-2]
assert node.inputs[0].ndim == 2, node
assert op.axis == (0,), op.axis
def test_local_sum_broadcast_some_1(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论