提交 be799d8f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add test for R_Op of OpFromGrah with multiple outputs

上级 53ec8e33
...@@ -310,6 +310,41 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -310,6 +310,41 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
dvval2 = fn(xval, Wval, duval) dvval2 = fn(xval, Wval, duval)
np.testing.assert_array_almost_equal(dvval2, dvval, 4) np.testing.assert_array_almost_equal(dvval2, dvval, 4)
def test_rop_multiple_outputs(self):
a = vector()
M = matrix()
b = dot(a, M)
op_matmul = OpFromGraph([a, M], [b, -b])
x = vector()
W = matrix()
du = vector()
xval = np.random.random((16,)).astype(config.floatX)
Wval = np.random.random((16, 16)).astype(config.floatX)
duval = np.random.random((16,)).astype(config.floatX)
y = op_matmul(x, W)[0]
dv = Rop(y, x, du)
fn = function([x, W, du], dv)
result_dvval = fn(xval, Wval, duval)
expected_dvval = np.dot(duval, Wval)
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
y = op_matmul(x, W)[1]
dv = Rop(y, x, du)
fn = function([x, W, du], dv)
result_dvval = fn(xval, Wval, duval)
expected_dvval = -np.dot(duval, Wval)
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
y = pt.add(*op_matmul(x, W))
dv = Rop(y, x, du)
fn = function([x, W, du], dv)
result_dvval = fn(xval, Wval, duval)
expected_dvval = np.zeros_like(np.dot(duval, Wval))
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论