提交 8757f5f1 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Use a Counter in tests/link/test_vm

上级 cc1a1cbb
import time
from collections import Counter
import numpy as np
import pytest
......@@ -34,11 +35,10 @@ class TestCallbacks:
# Test the `VMLinker`'s callback argument, which can be useful for debugging.
def setup_method(self):
self.n_callbacks = {}
self.n_callbacks = Counter()
def callback(self, node, thunk, storage_map, compute_map):
key = node.op.__class__.__name__
self.n_callbacks.setdefault(key, 0)
self.n_callbacks[key] += 1
def test_callback(self):
......@@ -50,9 +50,9 @@ class TestCallbacks:
)
f(1, 2, 3)
assert sum(self.n_callbacks.values()) == len(f.maker.fgraph.toposort())
assert self.n_callbacks.total() == len(f.maker.fgraph.toposort())
f(1, 2, 3)
assert sum(self.n_callbacks.values()) == len(f.maker.fgraph.toposort()) * 2
assert self.n_callbacks.total() == len(f.maker.fgraph.toposort()) * 2
def test_callback_with_ifelse(self):
a, b, c = scalars("abc")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论