提交 62f2ad19 authored 作者: Mohammed Affan's avatar Mohammed Affan

pass shape as tensorvariables and remove condition

上级 966a535d
......@@ -1717,7 +1717,7 @@ class Grouped_conv_noOptim(unittest.TestCase):
self.top_shape = [(5, 6, 3, 3), (4, 6, 3, 3), (3, 4, 3, 1), (2, 4, 5, 3)]
self.filter_dilation = (1, 1)
self.ref_mode = 'FAST_RUN'
if theano.config.cxx == "" or not theano.config.blas.ldflags:
if theano.config.cxx == "":
raise SkipTest("CorrMM needs cxx and blas")
def test_fwd(self):
......@@ -1772,7 +1772,7 @@ class Grouped_conv_noOptim(unittest.TestCase):
subsample=self.subsample,
filter_dilation=self.filter_dilation,
num_groups=groups)
grouped_conv_output = grouped_convgrad_op(img_sym, top_sym, kshp[-2:])
grouped_conv_output = grouped_convgrad_op(img_sym, top_sym, tensor.as_tensor_variable(kshp[-2:]))
if self.flip_filter:
grouped_conv_output = grouped_conv_output[:, :, ::-1, ::-1]
grouped_func = theano.function([img_sym, top_sym], grouped_conv_output, mode=self.mode)
......@@ -1795,7 +1795,7 @@ class Grouped_conv_noOptim(unittest.TestCase):
utt.assert_allclose(grouped_output, ref_concat_output)
def conv_gradweight(inputs_val, output_val):
return grouped_convgrad_op(inputs_val, output_val, kshp[-2:])
return grouped_convgrad_op(inputs_val, output_val, tensor.as_tensor_variable(kshp[-2:]))
utt.verify_grad(conv_gradweight,
[img, top],
......@@ -1815,9 +1815,9 @@ class Grouped_conv_noOptim(unittest.TestCase):
filter_dilation=self.filter_dilation,
num_groups=groups)
if self.flip_filter:
grouped_conv_output = grouped_convgrad_op(kern_sym[:, :, ::-1, ::-1], top_sym, imshp[-2:])
grouped_conv_output = grouped_convgrad_op(kern_sym[:, :, ::-1, ::-1], top_sym, tensor.as_tensor_variable(imshp[-2:]))
else:
grouped_conv_output = grouped_convgrad_op(kern_sym, top_sym, imshp[-2:])
grouped_conv_output = grouped_convgrad_op(kern_sym, top_sym, tensor.as_tensor_variable(imshp[-2:]))
grouped_func = theano.function([kern_sym, top_sym], grouped_conv_output, mode=self.mode)
assert any([isinstance(node.op, self.conv2d_gradi)
for node in grouped_func.maker.fgraph.toposort()])
......@@ -1838,7 +1838,7 @@ class Grouped_conv_noOptim(unittest.TestCase):
utt.assert_allclose(grouped_output, ref_concat_output)
def conv_gradinputs(filters_val, output_val):
return grouped_convgrad_op(filters_val, output_val, imshp[2:])
return grouped_convgrad_op(filters_val, output_val, tensor.as_tensor_variable(imshp[-2:]))
utt.verify_grad(conv_gradinputs,
[kern, top],
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论