提交 173826a7 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5096 from nouiz/fast_compile

fix tests in mode=FAST_COMPILE
......@@ -63,7 +63,8 @@ class TestCorr3DMM(unittest.TestCase):
filter_dilation=filter_dilation,
subsample=subsample),
[inputs_val.transpose(0, 4, 1, 2, 3),
filters_val.transpose(0, 4, 1, 2, 3)])
filters_val.transpose(0, 4, 1, 2, 3)],
mode=mode_with_gpu)
def test_valid(self):
self.run_conv_valid(inputs_shape=(16, 20, 12, 16, 1),
......@@ -146,12 +147,12 @@ class TestCorr3DMM(unittest.TestCase):
conv_gemm = GpuCorr3dMM_gradWeights(subsample=subsample)(
img, topgrad)
else:
conv_ref = GpuCorr3dMM_gradWeights(subsample=subsample)(
conv_ref = Corr3dMM_gradWeights(subsample=subsample)(
img, topgrad, shape=filters_shape[1:4])
conv_gemm = GpuCorr3dMM_gradWeights(subsample=subsample)(
img, topgrad, shape=filters_shape[1:4])
f_ref = theano.function([], conv_ref)
f_ref = theano.function([], conv_ref, mode='FAST_RUN')
f = theano.function([], conv_gemm, mode=mode_with_gpu)
res_ref = f_ref()
......@@ -205,7 +206,7 @@ class TestCorr3DMM(unittest.TestCase):
kern=weight, topgrad=top,
shape=bottom_shape)
f_ref = theano.function([], conv_ref)
f_ref = theano.function([], conv_ref, mode='FAST_RUN')
f = theano.function([], conv_gemm, mode=mode_with_gpu)
res_ref = f_ref()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论