提交 db26d266 authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

assert that convolution performed in the test is as specified

上级 ebab0a06
...@@ -1736,6 +1736,8 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1736,6 +1736,8 @@ class Grouped_conv_noOptim(unittest.TestCase):
else: else:
grouped_conv_output = grouped_abstractconv_op(img_sym, kern_sym) grouped_conv_output = grouped_abstractconv_op(img_sym, kern_sym)
grouped_func = theano.function([img_sym, kern_sym], grouped_conv_output, mode=self.mode) grouped_func = theano.function([img_sym, kern_sym], grouped_conv_output, mode=self.mode)
assert any([isinstance(node.op, self.conv2d)
for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(img, kern) grouped_output = grouped_func(img, kern)
normal_conv_op = conv2d_corr(img_sym, normal_conv_op = conv2d_corr(img_sym,
...@@ -1772,6 +1774,8 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1772,6 +1774,8 @@ class Grouped_conv_noOptim(unittest.TestCase):
if self.flip_filter: if self.flip_filter:
grouped_conv_output = grouped_conv_output[:, :, ::-1, ::-1] grouped_conv_output = grouped_conv_output[:, :, ::-1, ::-1]
grouped_func = theano.function([img_sym, top_sym], grouped_conv_output, mode=self.mode) grouped_func = theano.function([img_sym, top_sym], grouped_conv_output, mode=self.mode)
assert any([isinstance(node.op, self.conv2d_gradw)
for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(img, top) grouped_output = grouped_func(img, top)
normal_conv_op = conv2d_corr_gw(img_sym, normal_conv_op = conv2d_corr_gw(img_sym,
...@@ -1814,6 +1818,8 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1814,6 +1818,8 @@ class Grouped_conv_noOptim(unittest.TestCase):
else: else:
grouped_conv_output = grouped_abstractconvgrad_op(kern_sym, top_sym, imshp[-2:]) grouped_conv_output = grouped_abstractconvgrad_op(kern_sym, top_sym, imshp[-2:])
grouped_func = theano.function([kern_sym, top_sym], grouped_conv_output, mode=self.mode) grouped_func = theano.function([kern_sym, top_sym], grouped_conv_output, mode=self.mode)
assert any([isinstance(node.op, self.conv2d_gradi)
for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(kern, top) grouped_output = grouped_func(kern, top)
normal_conv_op = conv2d_corr_gi(kern_sym, normal_conv_op = conv2d_corr_gi(kern_sym,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论