提交 525c21c2 authored 作者: João Victor Risso's avatar João Victor Risso

Use functor to execute verify_grad on spatial transformer Op

上级 fa790f36
......@@ -2511,4 +2511,8 @@ def test_dnn_spatialtf_grad():
assert any([isinstance(node.op, dnn.GpuDnnTransformerGradT)
for node in grad_fn.maker.fgraph.toposort()])
utt.verify_grad(dnn.dnn_spatialtf(t_img, theta), [img], mode=mode_with_gpu)
\ No newline at end of file
def fn_wrt_i(img, theta):
op = dnn.dnn_spatialtf(img, theta)
return op
utt.verify_grad(fn_wrt_i, [img, theta], mode=mode_with_gpu)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论