提交 75336fb1 authored 作者: Frederic's avatar Frederic

Better conv2d error message

上级 36706769
...@@ -93,16 +93,30 @@ def conv2d(input, filters, image_shape=None, filter_shape=None, ...@@ -93,16 +93,30 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
image_shape = list(image_shape) image_shape = list(image_shape)
for i in xrange(len(image_shape)): for i in xrange(len(image_shape)):
if image_shape[i] is not None: if image_shape[i] is not None:
try:
image_shape[i] = get_scalar_constant_value( image_shape[i] = get_scalar_constant_value(
as_tensor_variable(image_shape[i])) as_tensor_variable(image_shape[i]))
except NotScalarConstantError, e:
raise NotScalarConstantError(
"The convolution need that the shape"
" information are constant values. We got"
" %s for the image_shape parameter" %
image_shape[i])
assert str(image_shape[i].dtype).startswith('int') assert str(image_shape[i].dtype).startswith('int')
image_shape[i] = int(image_shape[i]) image_shape[i] = int(image_shape[i])
if filter_shape is not None: if filter_shape is not None:
filter_shape = list(filter_shape) filter_shape = list(filter_shape)
for i in xrange(len(filter_shape)): for i in xrange(len(filter_shape)):
if filter_shape[i] is not None: if filter_shape[i] is not None:
try:
filter_shape[i] = get_scalar_constant_value( filter_shape[i] = get_scalar_constant_value(
as_tensor_variable(filter_shape[i])) as_tensor_variable(filter_shape[i]))
except NotScalarConstantError, e:
raise NotScalarConstantError(
"The convolution need that the shape"
" information are constant values. We got"
" %s for the filter_shape "
"parameter" % filter_shape[i])
assert str(filter_shape[i].dtype).startswith('int') assert str(filter_shape[i].dtype).startswith('int')
filter_shape[i] = int(filter_shape[i]) filter_shape[i] = int(filter_shape[i])
......
...@@ -7,7 +7,7 @@ from theano.tests import unittest_tools as utt ...@@ -7,7 +7,7 @@ from theano.tests import unittest_tools as utt
from theano.tensor.nnet import conv from theano.tensor.nnet import conv
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose, NotScalarConstantError
class TestConv2D(utt.InferShapeTester): class TestConv2D(utt.InferShapeTester):
...@@ -353,6 +353,20 @@ class TestConv2D(utt.InferShapeTester): ...@@ -353,6 +353,20 @@ class TestConv2D(utt.InferShapeTester):
N_image_shape=(3, 2, 8, 8), N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5)) N_filter_shape=(4, 2, 5, 5))
def test_wrong_info(self):
"""
Test convolutions when we don't give a constant as shape information
"""
i = theano.scalar.basic.int32()
self.assertRaises(NotScalarConstantError, self.validate,
(3, 2, 8, i), (4, 2, 5, 5),
N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5))
self.assertRaises(NotScalarConstantError, self.validate,
(3, 2, 8, 8), (4, 2, 5, i),
N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5))
def test_full_mode(self): def test_full_mode(self):
""" """
Tests basic convolution in full mode and case where filter Tests basic convolution in full mode and case where filter
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论