Change spatialtf test to dnn_spatialtf function

上级 b1513804
......@@ -2306,20 +2306,22 @@ def test_dnn_spatialtf_grid_generator():
utt.seed_rng()
# shape: (width, height, num_feature_maps, num_images)
dims = (256, 256, 3, 2)
grid_dims = (2, 2, 1, 2)
theta = np.asarray([[[1, 0, 0], [0, 1, 0]],
[[1, 0, 0], [0, 1, 0]]], dtype=np.float32)
[[1, 0, 0], [0, 1, 0]]], dtype=theano.config.floatX)
# Create Spatial Transformer descriptor
desc = dnn.dnn_spatialtf_context(dims)
theta_gpu = gpuarray_shared_constructor(theta)
context_name = infer_context_name(desc)
theta_gpu = as_gpuarray_variable(theta, context_name=context_name)
img = np.asarray([[[[1, 2],
[3, 4]]],
[[[5, 6],
[7, 8]]]], dtype=np.int32)
img_gpu = gpuarray_shared_constructor(img)
grid_generator = dnn.dnn_spatialtf_grid(desc, dims, theta_gpu)
spatialtf = dnn.dnn_spatialtf(img_gpu, theta_gpu, grid_dims)
grid_fn = theano.function([], [grid_generator], mode=mode_with_gpu)
spatialtf_fn = theano.function([], [spatialtf], mode=mode_with_gpu)
topo = grid_fn.maker.fgraph.toposort()
topo = spatialtf_fn.maker.fgraph.toposort()
assert len([n for n in topo if isinstance(n.op, dnn.GpuDnnGridGeneratorOp)]) == 1
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论