提交 d8a2e822 authored 作者: affanv14's avatar affanv14

Change the base test to better support all 4 versions

上级 c00b8cf2
...@@ -1705,6 +1705,9 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1705,6 +1705,9 @@ class Grouped_conv_noOptim(unittest.TestCase):
conv2d = staticmethod(theano.tensor.nnet.abstract_conv.AbstractConv2d) conv2d = staticmethod(theano.tensor.nnet.abstract_conv.AbstractConv2d)
conv2d_gradw = staticmethod(theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights) conv2d_gradw = staticmethod(theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights)
conv2d_gradi = staticmethod(theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs) conv2d_gradi = staticmethod(theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs)
conv2d_op = staticmethod(theano.tensor.nnet.abstract_conv.AbstractConv2d)
conv2d_gradw_op = staticmethod(theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights)
conv2d_gradi_op = staticmethod(theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs)
mode = theano.Mode(optimizer=None) mode = theano.Mode(optimizer=None)
flip_filter = False flip_filter = False
is_dnn = False is_dnn = False
...@@ -1739,8 +1742,7 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1739,8 +1742,7 @@ class Grouped_conv_noOptim(unittest.TestCase):
else: else:
grouped_conv_output = grouped_conv_op(img_sym, kern_sym) grouped_conv_output = grouped_conv_op(img_sym, kern_sym)
grouped_func = theano.function([img_sym, kern_sym], grouped_conv_output, mode=self.mode) grouped_func = theano.function([img_sym, kern_sym], grouped_conv_output, mode=self.mode)
if not self.is_dnn: assert any([isinstance(node.op, self.conv2d_op)
assert any([isinstance(node.op, self.conv2d)
for node in grouped_func.maker.fgraph.toposort()]) for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(img, kern) grouped_output = grouped_func(img, kern)
...@@ -1780,8 +1782,7 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1780,8 +1782,7 @@ class Grouped_conv_noOptim(unittest.TestCase):
if self.flip_filter: if self.flip_filter:
grouped_conv_output = grouped_conv_output[:, :, ::-1, ::-1] grouped_conv_output = grouped_conv_output[:, :, ::-1, ::-1]
grouped_func = theano.function([img_sym, top_sym], grouped_conv_output, mode=self.mode) grouped_func = theano.function([img_sym, top_sym], grouped_conv_output, mode=self.mode)
if not self.is_dnn: assert any([isinstance(node.op, self.conv2d_gradw_op)
assert any([isinstance(node.op, self.conv2d_gradw)
for node in grouped_func.maker.fgraph.toposort()]) for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(img, top) grouped_output = grouped_func(img, top)
...@@ -1827,8 +1828,7 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1827,8 +1828,7 @@ class Grouped_conv_noOptim(unittest.TestCase):
top_sym, top_sym,
tensor.as_tensor_variable(imshp if self.is_dnn else imshp[-2:])) tensor.as_tensor_variable(imshp if self.is_dnn else imshp[-2:]))
grouped_func = theano.function([kern_sym, top_sym], grouped_conv_output, mode=self.mode) grouped_func = theano.function([kern_sym, top_sym], grouped_conv_output, mode=self.mode)
if not self.is_dnn: assert any([isinstance(node.op, self.conv2d_gradi_op)
assert any([isinstance(node.op, self.conv2d_gradi)
for node in grouped_func.maker.fgraph.toposort()]) for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(kern, top) grouped_output = grouped_func(kern, top)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论