提交 fca95a18 authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

Improve tests to check grad for all three passes

上级 cdef0484
...@@ -1743,6 +1743,10 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1743,6 +1743,10 @@ class Grouped_conv_noOptim(unittest.TestCase):
utt.assert_allclose(grouped_output, normal_concat_output) utt.assert_allclose(grouped_output, normal_concat_output)
utt.verify_grad(grouped_abstractconv_func,
[img, kern],
mode=self.mode)
def test_gradweights(self): def test_gradweights(self):
img = np.random.random(self.img_shape).astype('float32') img = np.random.random(self.img_shape).astype('float32')
top = np.random.random(self.top_shape).astype('float32') top = np.random.random(self.top_shape).astype('float32')
...@@ -1775,6 +1779,13 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1775,6 +1779,13 @@ class Grouped_conv_noOptim(unittest.TestCase):
utt.assert_allclose(grouped_output, normal_concat_output) utt.assert_allclose(grouped_output, normal_concat_output)
def abstract_conv_gradweight(inputs_val, output_val):
return grouped_abstractconvgrad_func(inputs_val, output_val, self.kern_shape[-2:])
utt.verify_grad(abstract_conv_gradweight,
[img, top],
mode=self.mode, eps=1)
def test_gradinputs(self): def test_gradinputs(self):
kern = np.random.random(self.kern_shape).astype('float32') kern = np.random.random(self.kern_shape).astype('float32')
top = np.random.random(self.top_shape).astype('float32') top = np.random.random(self.top_shape).astype('float32')
...@@ -1807,3 +1818,10 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1807,3 +1818,10 @@ class Grouped_conv_noOptim(unittest.TestCase):
normal_concat_output = np.concatenate(normal_concat_output, axis=1) normal_concat_output = np.concatenate(normal_concat_output, axis=1)
utt.assert_allclose(grouped_output, normal_concat_output) utt.assert_allclose(grouped_output, normal_concat_output)
def abstract_conv_gradinputs(filters_val, output_val):
return grouped_abstractconvgrad_func(filters_val, output_val, self.img_shape[2:])
utt.verify_grad(abstract_conv_gradinputs,
[kern, top],
mode=self.mode, eps=1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论