提交 47207edb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix flaky TestOpFromGraph tests

上级 34254516
......@@ -92,7 +92,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5
assert np.allclose(6.0, fn(xv, yv, zv))
np.testing.assert_array_almost_equal(6.0, fn(xv, yv, zv), 4)
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
......@@ -111,8 +111,8 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
zv = np.ones((2, 2), dtype=config.floatX) * 5
# print function, function.__module__
# print fn.maker.fgraph.toposort()
assert np.allclose(8.0, fn(xv, yv, zv))
assert np.allclose(8.0, fn(xv, yv, zv))
np.testing.assert_array_almost_equal(8.0, fn(xv, yv, zv), 4)
np.testing.assert_array_almost_equal(8.0, fn(xv, yv, zv), 4)
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
......@@ -128,13 +128,13 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5
assert np.allclose(11.0 + s.get_value(), fn(xv, yv, zv))
np.testing.assert_array_almost_equal(11.0 + s.get_value(), fn(xv, yv, zv), 4)
# grad again the shared variable
f = op(x, y, z)
f = f - grad(aet_sum(f), s)
fn = function([x, y, z], f)
assert np.allclose(15.0 + s.get_value(), fn(xv, yv, zv))
np.testing.assert_array_almost_equal(15.0 + s.get_value(), fn(xv, yv, zv), 4)
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
......@@ -162,8 +162,8 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
xv = np.random.rand(16).astype(config.floatX)
yv = np.random.rand(16).astype(config.floatX)
dxv, dyv = fn(xv, yv)
assert np.allclose(yv * 2, dxv)
assert np.allclose(xv * 1.5, dyv)
np.testing.assert_array_almost_equal(yv * 2, dxv, 4)
np.testing.assert_array_almost_equal(xv * 1.5, dyv, 4)
# list override case
def go1(inps, gs):
......@@ -189,9 +189,9 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
wv = np.random.rand(16).astype(config.floatX)
bv = np.random.rand(16).astype(config.floatX)
dxv, dwv, dbv = fn(xv, wv, bv)
assert np.allclose(wv * 2, dxv)
assert np.allclose(xv * 1.5, dwv)
assert np.allclose(np.ones(16, dtype=config.floatX), dbv)
np.testing.assert_array_almost_equal(wv * 2, dxv, 4)
np.testing.assert_array_almost_equal(xv * 1.5, dwv, 4)
np.testing.assert_array_almost_equal(np.ones(16, dtype=config.floatX), dbv, 4)
# NullType and DisconnectedType
op_linear2 = cls_ofg(
......@@ -239,7 +239,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
xval = np.random.rand(32).astype(config.floatX)
y1val, y2val = fn(xval)
assert np.allclose(y1val, y2val)
np.testing.assert_array_almost_equal(y1val, y2val, 4)
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
......@@ -260,7 +260,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
duval = np.random.rand(16).astype(config.floatX)
dvval = np.dot(duval, Wval)
dvval2 = fn(xval, Wval, duval)
assert np.allclose(dvval2, dvval)
np.testing.assert_array_almost_equal(dvval2, dvval, 4)
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
......@@ -287,7 +287,9 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
fn = function([xx, yy, du, dv], dw)
vals = np.random.rand(4, 32).astype(config.floatX)
dwval = fn(*vals)
assert np.allclose(dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.0)
np.testing.assert_array_almost_equal(
dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.0, 4
)
# TODO list override case
......@@ -321,7 +323,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
out = g1.eval(
{x: np.ones((5,), dtype=np.float32), y: np.ones((5,), dtype=np.float32)}
)
assert np.allclose(out, [1.0] * 5)
np.testing.assert_array_almost_equal(out, [1.0] * 5, 4)
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
......@@ -339,8 +341,8 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
xv = np.random.rand(16).astype(config.floatX)
yv = np.random.rand(16).astype(config.floatX)
xv2, yv2 = fn(xv, yv)
assert np.allclose(xv, xv2)
assert np.allclose(yv, yv2)
np.testing.assert_array_almost_equal(xv, xv2, 4)
np.testing.assert_array_almost_equal(yv, yv2, 4)
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论