Add functions rescale input and use a rotation matrix

上级 e36cb332
......@@ -2306,25 +2306,39 @@ def test_dnn_spatialtf_grid_generator():
utt.seed_rng()
# shape: (num_images, channels, height, width), equivalent to NCHW
grid_dims = (10, 3, 64, 128)
grid_dims = (3, 3, 128, 128)
identity = [[1, 0, 0],
[0, 1, 0]]
identity = [[-1, 0, 0],
[0, -1, 0]]
theta = np.asarray(grid_dims[0] * [identity], dtype=theano.config.floatX)
float_type = theano.config.floatX
theta = np.asarray(grid_dims[0] * [identity], dtype=float_type)
theta_gpu = gpuarray_shared_constructor(theta)
def normalize_input(input):
# Scale input from [0, 255] to [0, 2]
scale_factor = 1. / 128.
input *= scale_factor
# Re-scale input from [0, 2] to [-1, 1] (normalized)
input -= 1
return input
def rescale_input(input):
# Re-scale output to range [0, 2]
input += 1
# Re-scale output to range [0, 255]
input *= 128
return input
from scipy import misc
f = misc.face().astype(np.float32)
f = misc.face().astype(float_type)
# Convert from HWC to CHW
f = np.transpose(f, axes=(2, 0, 1))
# Scale input from [0, 255] to [0, 2]
sc = 1. / 128.
f *= sc
# Re-scale input from [0, 2] to [-1, 1] (normalized)
f -= 1
f = normalize_input(f)
# Create array of images
img = np.asarray(grid_dims[0] * [f], dtype=theano.config.floatX)
img = np.asarray(grid_dims[0] * [f], dtype=float_type)
# Create GPU variable for the images
img_gpu = gpuarray_shared_constructor(img)
......@@ -2334,23 +2348,14 @@ def test_dnn_spatialtf_grid_generator():
result, = spatialtf_fn()
img_out = np.asarray(result, dtype=np.float32)
print(img_out.shape)
for i in range(len(img_out)):
# Re-scale output to range [0, 2]
img_out[i] += 1
# Re-scale output to range [0, 255]
img_out[i] *= 128
img_out = np.asarray(result, dtype=float_type)
# Re-scale image to range [0, 255]
img_out = rescale_input(img_out)
# Convert to uint8 (byte)
img_out = img_out.astype(dtype=np.uint8)
# Transpose back to NHWC
img_out = np.transpose(img_out, axes=(0, 2, 3, 1))
for i in range(len(img_out)):
print("[sampled image #{0}]".format(i))
print("Min/Max: {0}/{1}".format(img_out[i].min(), img_out[i].max()))
print(img_out[i])
import matplotlib.pyplot as plt
for img_idx in range(len(img_out)):
plt.imshow(img_out[img_idx])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论