提交 ba16d2d6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to tests.sandbox

上级 044f52ec
......@@ -45,7 +45,7 @@ def test_rop_lop():
v1 = rop_f(vx, vv)
v2 = scan_f(vx, vv)
assert _allclose(v1, v2), "ROP mismatch: %s %s" % (v1, v2)
assert _allclose(v1, v2), "ROP mismatch: {} {}".format(v1, v2)
raised = False
try:
......@@ -54,10 +54,8 @@ def test_rop_lop():
raised = True
if not raised:
raise Exception(
(
"Op did not raised an error even though the function"
" is not differentiable"
)
"Op did not raised an error even though the function"
" is not differentiable"
)
vv = np.asarray(rng.uniform(size=(4,)), theano.config.floatX)
......@@ -69,7 +67,7 @@ def test_rop_lop():
v1 = lop_f(vx, vv)
v2 = scan_f(vx, vv)
assert _allclose(v1, v2), "LOP mismatch: %s %s" % (v1, v2)
assert _allclose(v1, v2), "LOP mismatch: {} {}".format(v1, v2)
def test_spectral_radius_bound():
......
......@@ -217,7 +217,9 @@ def check_basics(
if hasattr(target_avg, "shape"): # looks if target_avg is an array
diff = np.mean(abs(mean - target_avg))
# print prefix, 'mean diff with mean', diff
assert np.all(diff < mean_rtol * (1 + abs(target_avg))), "bad mean? %s %s" % (
assert np.all(
diff < mean_rtol * (1 + abs(target_avg))
), "bad mean? {} {}".format(
mean,
target_avg,
)
......@@ -228,7 +230,7 @@ def check_basics(
# print prefix, 'mean', mean
assert abs(mean - target_avg) < mean_rtol * (
1 + abs(target_avg)
), "bad mean? %f %f" % (mean, target_avg)
), "bad mean? {:f} {:f}".format(mean, target_avg)
std = np.sqrt(avg_var)
# print prefix, 'var', avg_var
......@@ -236,7 +238,7 @@ def check_basics(
if target_std is not None:
assert abs(std - target_std) < std_tol * (
1 + abs(target_std)
), "bad std? %f %f %f" % (std, target_std, std_tol)
), "bad std? {:f} {:f} {:f}".format(std, target_std, std_tol)
# print prefix, 'time', dt
# print prefix, 'elements', steps * sample_size[0] * sample_size[1]
# print prefix, 'samples/sec', steps * sample_size[0] * sample_size[1] / dt
......@@ -598,11 +600,11 @@ def test_normal_truncation():
# check if truncated at 2*std
samples = f(*input)
assert np.all(avg + 2 * std - samples >= 0), "bad upper bound? %s %s" % (
assert np.all(avg + 2 * std - samples >= 0), "bad upper bound? {} {}".format(
samples,
avg + 2 * std,
)
assert np.all(samples - (avg - 2 * std) >= 0), "bad lower bound? %s %s" % (
assert np.all(samples - (avg - 2 * std) >= 0), "bad lower bound? {} {}".format(
samples,
avg - 2 * std,
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论