提交 8757f5f1 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Use a Counter in tests/link/test_vm

上级 cc1a1cbb
import time import time
from collections import Counter
import numpy as np import numpy as np
import pytest import pytest
...@@ -34,11 +35,10 @@ class TestCallbacks: ...@@ -34,11 +35,10 @@ class TestCallbacks:
# Test the `VMLinker`'s callback argument, which can be useful for debugging. # Test the `VMLinker`'s callback argument, which can be useful for debugging.
def setup_method(self): def setup_method(self):
self.n_callbacks = {} self.n_callbacks = Counter()
def callback(self, node, thunk, storage_map, compute_map): def callback(self, node, thunk, storage_map, compute_map):
key = node.op.__class__.__name__ key = node.op.__class__.__name__
self.n_callbacks.setdefault(key, 0)
self.n_callbacks[key] += 1 self.n_callbacks[key] += 1
def test_callback(self): def test_callback(self):
...@@ -50,9 +50,9 @@ class TestCallbacks: ...@@ -50,9 +50,9 @@ class TestCallbacks:
) )
f(1, 2, 3) f(1, 2, 3)
assert sum(self.n_callbacks.values()) == len(f.maker.fgraph.toposort()) assert self.n_callbacks.total() == len(f.maker.fgraph.toposort())
f(1, 2, 3) f(1, 2, 3)
assert sum(self.n_callbacks.values()) == len(f.maker.fgraph.toposort()) * 2 assert self.n_callbacks.total() == len(f.maker.fgraph.toposort()) * 2
def test_callback_with_ifelse(self): def test_callback_with_ifelse(self):
a, b, c = scalars("abc") a, b, c = scalars("abc")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论