提交 173826a7 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5096 from nouiz/fast_compile

fix tests in mode=FAST_COMPILE
...@@ -63,7 +63,8 @@ class TestCorr3DMM(unittest.TestCase): ...@@ -63,7 +63,8 @@ class TestCorr3DMM(unittest.TestCase):
filter_dilation=filter_dilation, filter_dilation=filter_dilation,
subsample=subsample), subsample=subsample),
[inputs_val.transpose(0, 4, 1, 2, 3), [inputs_val.transpose(0, 4, 1, 2, 3),
filters_val.transpose(0, 4, 1, 2, 3)]) filters_val.transpose(0, 4, 1, 2, 3)],
mode=mode_with_gpu)
def test_valid(self): def test_valid(self):
self.run_conv_valid(inputs_shape=(16, 20, 12, 16, 1), self.run_conv_valid(inputs_shape=(16, 20, 12, 16, 1),
...@@ -146,12 +147,12 @@ class TestCorr3DMM(unittest.TestCase): ...@@ -146,12 +147,12 @@ class TestCorr3DMM(unittest.TestCase):
conv_gemm = GpuCorr3dMM_gradWeights(subsample=subsample)( conv_gemm = GpuCorr3dMM_gradWeights(subsample=subsample)(
img, topgrad) img, topgrad)
else: else:
conv_ref = GpuCorr3dMM_gradWeights(subsample=subsample)( conv_ref = Corr3dMM_gradWeights(subsample=subsample)(
img, topgrad, shape=filters_shape[1:4]) img, topgrad, shape=filters_shape[1:4])
conv_gemm = GpuCorr3dMM_gradWeights(subsample=subsample)( conv_gemm = GpuCorr3dMM_gradWeights(subsample=subsample)(
img, topgrad, shape=filters_shape[1:4]) img, topgrad, shape=filters_shape[1:4])
f_ref = theano.function([], conv_ref) f_ref = theano.function([], conv_ref, mode='FAST_RUN')
f = theano.function([], conv_gemm, mode=mode_with_gpu) f = theano.function([], conv_gemm, mode=mode_with_gpu)
res_ref = f_ref() res_ref = f_ref()
...@@ -205,7 +206,7 @@ class TestCorr3DMM(unittest.TestCase): ...@@ -205,7 +206,7 @@ class TestCorr3DMM(unittest.TestCase):
kern=weight, topgrad=top, kern=weight, topgrad=top,
shape=bottom_shape) shape=bottom_shape)
f_ref = theano.function([], conv_ref) f_ref = theano.function([], conv_ref, mode='FAST_RUN')
f = theano.function([], conv_gemm, mode=mode_with_gpu) f = theano.function([], conv_gemm, mode=mode_with_gpu)
res_ref = f_ref() res_ref = f_ref()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论