提交 9f8d0b2d authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

small fix for tests

上级 11008559
......@@ -430,14 +430,14 @@ class TestGroupCorr2d(Grouped_conv_noOptim):
def test_graph(self):
# define common values first
groups = 3
bottom = np.random.rand(3, 6, 5, 5)
kern = np.random.rand(9, 2, 3, 3)
bottom = np.random.rand(3, 6, 5, 5).astype('float32')
kern = np.random.rand(9, 2, 3, 3).astype('float32')
bottom_sym = T.tensor4('bottom')
kern_sym = T.tensor4('kern')
# grouped convolution graph
conv_group = self.conv2d(num_groups=groups)(bottom_sym, kern_sym)
gconv_func = theano.function([bottom_sym, kern_sym], conv_group)
gconv_func = theano.function([bottom_sym, kern_sym], conv_group, mode=self.mode)
# Graph for the normal hard way
kern_offset = kern_sym.shape[0] // groups
......@@ -446,7 +446,7 @@ class TestGroupCorr2d(Grouped_conv_noOptim):
kern_sym[i * kern_offset:(i + 1) * kern_offset, :, :, :])
for i in range(groups)]
concatenated_output = T.concatenate(split_conv_output, axis=1)
conv_func = theano.function([bottom_sym, kern_sym], concatenated_output)
conv_func = theano.function([bottom_sym, kern_sym], concatenated_output, mode=self.mode)
# calculate outputs for each graph
gconv_output = gconv_func(bottom, kern)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论