提交 65ea94c5 authored 作者: affanv14's avatar affanv14

add tests for corr3dmm group convolutions

上级 10bac9b1
...@@ -12,6 +12,7 @@ import theano ...@@ -12,6 +12,7 @@ import theano
import theano.tensor as T import theano.tensor as T
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet import corr3d, conv from theano.tensor.nnet import corr3d, conv
from theano.tensor.nnet.tests.test_abstract_conv import Grouped_conv3d_noOptim
class TestCorr3D(utt.InferShapeTester): class TestCorr3D(utt.InferShapeTester):
...@@ -418,6 +419,21 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -418,6 +419,21 @@ class TestCorr3D(utt.InferShapeTester):
self.validate((3, 1, 7, 5, 5), (2, 1, 2, 3, 3), (2, 1, 1), non_contiguous=True) self.validate((3, 1, 7, 5, 5), (2, 1, 2, 3, 3), (2, 1, 1), non_contiguous=True)
class TestGroupCorr3d(Grouped_conv3d_noOptim):
if theano.config.mode == "FAST_COMPILE":
mode = theano.compile.get_mode("FAST_RUN")
else:
mode = None
conv = corr3d.Corr3dMM
conv_gradw = corr3d.Corr3dMM_gradWeights
conv_gradi = corr3d.Corr3dMM_gradInputs
conv_op = corr3d.Corr3dMM
conv_gradw_op = corr3d.Corr3dMM_gradWeights
conv_gradi_op = corr3d.Corr3dMM_gradInputs
flip_filter = True
is_dnn = False
if __name__ == '__main__': if __name__ == '__main__':
t = TestCorr3D('setUp') t = TestCorr3D('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论