提交 b07912fc authored 作者: Razvan Pascanu's avatar Razvan Pascanu 提交者: Pascal Lamblin

tests for convolution R op

Conflicts: theano/tests/test_rop.py
上级 a1c0d613
......@@ -21,6 +21,8 @@ import numpy
from theano.gof import Op, Apply
from theano.gradient import grad_undefined
from numpy.testing.noseclasses import KnownFailureTest
from theano.tensor.signal.downsample import DownsampleFactorMax
from theano.tensor.nnet import conv
'''
Special Op created to test what happens when you have one op that is not
......@@ -262,6 +264,54 @@ class test_RopLop(RopLop_checker):
self.x[:4].dimshuffle('x', 0), 0).sum(axis=1),
(1,))
def test_conv(self):
for border_mode in ['valid', 'full']:
image_shape = (2, 2, 4, 5)
filter_shape = (2, 2, 2, 3)
image_dim = len(image_shape)
filter_dim = len(filter_shape)
input = tensor.TensorType('float64', [False] *
image_dim)(name='input')
filters = tensor.TensorType('float64', [False] *
filter_dim)(name='filter')
ev_input = tensor.TensorType('float64', [False] *
image_dim)(name='ev_input')
ev_filters = tensor.TensorType('float64', [False] *
filter_dim)(name='ev_filters')
bsize = image_shape[0]
if image_dim != 3:
bsize = 1
nkern = filter_shape[0]
if filter_dim != 3:
nkern = 1
def sym_conv2d(input, filters):
return conv.conv2d(input, filters, border_mode=border_mode)
output = sym_conv2d(input, filters).flatten()
yv = tensor.Rop(output, [input, filters], [ev_input, ev_filters])
rop_f = function([input, filters, ev_input, ev_filters],
yv, on_unused_input='ignore')
sy, _ = theano.scan(
lambda i, y, x1, x2, v1, v2:
(tensor.grad(y[i], x1) * v1).sum() + \
(tensor.grad(y[i], x2) * v2).sum(),
sequences = tensor.arange(output.shape[0]),
non_sequences=[output, input, filters,
ev_input, ev_filters])
scan_f = function([input, filters, ev_input, ev_filters], sy,
on_unused_input='ignore')
image_data = numpy.random.random(image_shape)
filter_data = numpy.random.random(filter_shape)
ev_image_data = numpy.random.random(image_shape)
ev_filter_data = numpy.random.random(filter_shape)
v1 = rop_f(image_data, filter_data, ev_image_data,
ev_filter_data)
v2 = scan_f(image_data, filter_data, ev_image_data,
ev_filter_data)
assert numpy.allclose(v1, v2), ("Rop mismatch: %s %s" %
(v1,v2))
def test_join(self):
tv = numpy.asarray(self.rng.uniform(size=(10,)),
theano.config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论