提交 8467591f authored 作者: João Victor Risso's avatar João Victor Risso

Add basic verify_grad test to check transformer's grad wrt to inputs

上级 f5abe95e
......@@ -2480,7 +2480,7 @@ def test_dnn_spatialtf_grad():
utt.seed_rng()
# Generate random set of RGB images with 256x256 resolution (pixel values in [0, 255])
img_dims = (10, 256, 256, 3) # images are usually NHWC
img_dims = (1, 256, 256, 3) # images are usually NHWC
img = np.random.randint(low=0, high=256, size=img_dims)
# Convert from NHWC to NCHW
img = np.transpose(img, axes=(0, 3, 1, 2)).astype(theano.config.floatX)
......@@ -2510,3 +2510,5 @@ def test_dnn_spatialtf_grad():
assert any([isinstance(node.op, dnn.GpuDnnTransformerGradT)
for node in grad_fn.maker.fgraph.toposort()])
utt.verify_grad(dnn.dnn_spatialtf(t_img, theta), [img], mode=mode_with_gpu)
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论