提交 e183ab3b authored 作者: Tegan Maharaj's avatar Tegan Maharaj

added proper comparison to conv2d.grad

上级 3b740511
......@@ -1606,6 +1606,7 @@ class TestConv2dGrads(unittest.TestCase):
self.filter_flip = [True, False]
self.output_grad = T.tensor4()
self.output_grad_wrt = T.tensor4()
self.filters = T.tensor4()
self.x = T.tensor4('x', theano.config.floatX) #inputs
......@@ -1628,20 +1629,32 @@ class TestConv2dGrads(unittest.TestCase):
results are the same.
"""
for (in_shape, fltr_shape) in zip(inputs_shapes, filters_shapes):
for bm in border_modes:
for ss in subsamples:
for ff in filter_flip:
conv_out = T.nnet.conv.conv2d(x,
filters = filters,
for (in_shape, fltr_shape) in zip(self.inputs_shapes, self.filters_shapes):
for bm in self.border_modes:
for ss in self.subsamples:
for ff in self.filter_flip:
if filter_flip = True:
fltr_shape = transpose(fltr_shape) #conv2d doesn't seem to have filter_flip
conv_out = T.nnet.conv.conv2d(self.x,
filters = self.filters,
border_mode = bm,
subsample = ss,
image_shape = in_shape
filter_shape = fltr_shape
)
conv_grad = theano.grad(wrt=[x], known_grads={conv_out: output_grad})
conv_grad = theano.grad(wrt=[x], known_grads={conv_out: self.output_grad})
f_prime = theano.function([x, output_grad, filters], conv_grad)
utt.assert_allclose(conv_grad, f_prime)
conv_wrt_i_out = T.nnet.conv.abstract_conv.conv2d_grad_wrt_inputs(self.output_grad_wrt,
filters = self.filters,
border_mode = bm,
subsample = ss,
input_shape = in_shape,
filter_shape = fltr_shape,
filter_flip = ff
)
f = theano.function([x, output_grad_wrt, filters], conv_wrt_i_out)
utt.assert_allclose(f, f_prime)
def test_conv2d_grad_wrt_weights():
"""Compares calculated abstract grads wrt weights with the fwd grads
......@@ -1652,47 +1665,34 @@ class TestConv2dGrads(unittest.TestCase):
"""
for (in_shape, fltr_shape) in zip(inputs_shapes, filters_shapes):
for bm in border_modes:
for ss in subsamples:
for ff in filter_flip:
conv_out = T.nnet.conv.conv2d(w,
filters = filters,
for (in_shape, fltr_shape) in zip(self.inputs_shapes, self.filters_shapes):
for bm in self.border_modes:
for ss in self.subsamples:
for ff in self.filter_flip:
if filter_flip = True:
fltr_shape = transpose(fltr_shape) #conv2d doesn't seem to have filter_flip
conv_out = T.nnet.conv.conv2d(self.w,
filters = self.filters,
border_mode = bm,
subsample = ss,
image_shape = in_shape
image_shape = in_shape,
filter_shape = fltr_shape
)
conv_grad = theano.grad(wrt=[w], known_grads={conv_out: output_grad})
f_prime = theano.function([w, output_grad, filters], conv_grad)
utt.assert_allclose(conv_grad, f_prime)
def test_conv2_grads_wrt_input_and_weights():
"""Compares calculated abstract grads wrt [inputs, weights] with the fwd grads
conv_wrt_w_out = T.nnet.conv.abstract_conv.conv2d_grad_wrt_weights(self.output_grad_wrt,
filters = self.filters,
border_mode = bm,
subsample = ss,
input_shape = in_shape,
filter_shape = fltr_shape,
filter_flip = ff
)
f = theano.function([w, output_grad_wrt, filters], conv_wrt_w_out)
utt.assert_allclose(f, f_prime)
This method checks the outputs of conv2_grad wrt inputs and weights
against the outputs of T.nnet.conv forward grads to make sure the
results are the same.
"""
#for (input_shape, filter_shape) in zip(inputs_shapes, filter_shapes):
#make_rand_inputs = theano.function([x], random_stream(input_shape))
#make_rand_filters = theano.function([filters], random_stream(filter_shape))
for (in_shape, fltr_shape) in zip(inputs_shapes, filters_shapes):
for wrt in [x,w]: #for both inputs (x) and weights (w)
for bm in border_modes:
for ss in subsamples:
for ff in filter_flip:
conv_out = T.nnet.conv.conv2d(wrt,
filters = filters,
border_mode = bm,
subsample = ss,
image_shape = in_shape
filter_shape = fltr_shape
)
conv_grad = theano.grad(wrt=[wrt], known_grads={conv_out: output_grad})
f_prime = theano.function([wrt, output_grad, filters], conv_grad)
utt.assert_allclose(conv_grad, f_prime)
def test_conv2_grads_wrt_input_and_weights():
test_conv2d_grad_wrt_inputs()
test_conv2d_grad_wrt_weights()
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论