提交 0ab2de0c authored 作者: Michael Osthege's avatar Michael Osthege

Fix Windows compatibility of some tests

上级 c7b8ddfb
...@@ -128,7 +128,7 @@ def test_cache_versioning(): ...@@ -128,7 +128,7 @@ def test_cache_versioning():
z = my_add(x) z = my_add(x)
z_v = my_add_ver(x) z_v = my_add_ver(x)
with tempfile.TemporaryDirectory() as dir_name: with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as dir_name:
cache = ModuleCache(dir_name) cache = ModuleCache(dir_name)
lnk = CLinker().accept(FunctionGraph(outputs=[z])) lnk = CLinker().accept(FunctionGraph(outputs=[z]))
......
import os import os
import string
import subprocess import subprocess
import sys import sys
import tempfile
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
...@@ -37,7 +37,7 @@ class QuadraticCOpFunc(ExternalCOp): ...@@ -37,7 +37,7 @@ class QuadraticCOpFunc(ExternalCOp):
def __init__(self, a, b, c): def __init__(self, a, b, c):
super().__init__( super().__init__(
"{test_dir}/c_code/test_quadratic_function.c", "APPLY_SPECIFIC(compute_quadratic)" "{str(test_dir).replace(os.sep, "/")}/c_code/test_quadratic_function.c", "APPLY_SPECIFIC(compute_quadratic)"
) )
self.a = a self.a = a
self.b = b self.b = b
...@@ -215,9 +215,10 @@ def get_hash(modname, seed=None): ...@@ -215,9 +215,10 @@ def get_hash(modname, seed=None):
def test_ExternalCOp_c_code_cache_version(): def test_ExternalCOp_c_code_cache_version():
"""Make sure the C cache versions produced by `ExternalCOp` don't depend on `hash` seeding.""" """Make sure the C cache versions produced by `ExternalCOp` don't depend on `hash` seeding."""
with tempfile.NamedTemporaryFile(dir=".", suffix=".py") as tmp: tmp = Path() / ("".join(np.random.choice(list(string.ascii_letters), 8)) + ".py")
tmp.write(externalcop_test_code.encode()) tmp.write_bytes(externalcop_test_code.encode())
tmp.seek(0)
try:
modname = tmp.name modname = tmp.name
out_1, err1, returncode1 = get_hash(modname, seed=428) out_1, err1, returncode1 = get_hash(modname, seed=428)
out_2, err2, returncode2 = get_hash(modname, seed=3849) out_2, err2, returncode2 = get_hash(modname, seed=3849)
...@@ -225,9 +226,11 @@ def test_ExternalCOp_c_code_cache_version(): ...@@ -225,9 +226,11 @@ def test_ExternalCOp_c_code_cache_version():
assert returncode2 == 0 assert returncode2 == 0
assert err1 == err2 assert err1 == err2
hash_1, msg, _ = out_1.decode().split("\n") hash_1, msg, _ = out_1.decode().split(os.linesep)
assert msg == "__success__" assert msg == "__success__"
hash_2, msg, _ = out_2.decode().split("\n") hash_2, msg, _ = out_2.decode().split(os.linesep)
assert msg == "__success__" assert msg == "__success__"
assert hash_1 == hash_2 assert hash_1 == hash_2
finally:
tmp.unlink()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论