提交 11008559 authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

add another test for corrmm

上级 ac2c1e5b
...@@ -427,6 +427,35 @@ class TestGroupCorr2d(Grouped_conv_noOptim): ...@@ -427,6 +427,35 @@ class TestGroupCorr2d(Grouped_conv_noOptim):
conv2d_gradi = staticmethod(corr.CorrMM_gradInputs) conv2d_gradi = staticmethod(corr.CorrMM_gradInputs)
flip_filter = True flip_filter = True
def test_graph(self):
# define common values first
groups = 3
bottom = np.random.rand(3, 6, 5, 5)
kern = np.random.rand(9, 2, 3, 3)
bottom_sym = T.tensor4('bottom')
kern_sym = T.tensor4('kern')
# grouped convolution graph
conv_group = self.conv2d(num_groups=groups)(bottom_sym, kern_sym)
gconv_func = theano.function([bottom_sym, kern_sym], conv_group)
# Graph for the normal hard way
kern_offset = kern_sym.shape[0] // groups
bottom_offset = bottom_sym.shape[1] // groups
split_conv_output = [self.conv2d()(bottom_sym[:, i * bottom_offset:(i + 1) * bottom_offset, :, :],
kern_sym[i * kern_offset:(i + 1) * kern_offset, :, :, :])
for i in range(groups)]
concatenated_output = T.concatenate(split_conv_output, axis=1)
conv_func = theano.function([bottom_sym, kern_sym], concatenated_output)
# calculate outputs for each graph
gconv_output = gconv_func(bottom, kern)
conv_output = conv_func(bottom, kern)
# compare values
utt.assert_allclose(gconv_output, conv_output)
if __name__ == '__main__': if __name__ == '__main__':
t = TestCorr2D('setUp') t = TestCorr2D('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论