提交 121ec69a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Finish, clean-up, and test

上级 0d4e482d
......@@ -4,6 +4,7 @@ Abstract conv interface
import logging
from six import reraise
import sys
import theano
......@@ -414,26 +415,30 @@ class BaseAbstractConv2d(Op):
'"valid", "full", "half", an integer or a pair of'
' integers'.format(border_mode))
self.imshp = tuple(imshp) if imshp else None
self.imshp = tuple(imshp) if imshp else (None,) * 4
for imshp_i in self.imshp:
if imshp_i is not None:
# Components of imshp should be constant or ints
try:
get_scalar_constant_value(imshp_i)
get_scalar_constant_value(imshp_i,
only_process_constants=True)
except NotScalarConstantError:
_logger.error("imshp should be None or "
"a tuple of constant int values")
raise
self.kshp = tuple(kshp) if kshp else None
reraise(ValueError,
ValueError("imshp should be None or a tuple of "
"constant int values"),
sys.exc_info()[2])
self.kshp = tuple(kshp) if kshp else (None,) * 4
for kshp_i in self.kshp:
if kshp_i is not None:
# Components of kshp should be constant or ints
try:
get_scalar_constant_value(kshp_i)
get_scalar_constant_value(kshp_i,
only_process_constants=True)
except NotScalarConstantError:
_logger.error("kshp should be None or "
"a tuple of constant int values")
raise
reraise(ValueError,
ValueError("kshp should be None or a tuple of "
"constant int values"),
sys.exc_info()[2])
self.border_mode = border_mode
self.filter_flip = filter_flip
......
......@@ -2,12 +2,16 @@ import numpy
import unittest
from nose.plugins.skip import SkipTest
from nose.tools import assert_raises
import theano
from theano import tensor
from theano.tests import unittest_tools as utt
from theano.tensor.nnet import corr, abstract_conv as conv
from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.nnet.abstract_conv import AbstractConv2d
from theano.tensor.nnet.abstract_conv import AbstractConv2d_gradInputs
from theano.tensor.nnet.abstract_conv import AbstractConv2d_gradWeights
from theano.tensor.nnet.conv import ConvOp
from theano.tensor.nnet.corr import (CorrMM, CorrMM_gradWeights,
CorrMM_gradInputs)
......@@ -384,6 +388,53 @@ class TestCpuConv2d(BaseTestConv2d):
filter_flip=flip)
def test_constant_shapes():
# Check that the `imshp` and `kshp` parameters of the AbstractConv Ops
# are rejected if not constant or None
dummy_t4 = tensor.ftensor4()
alloc_dummy_t4 = tensor.zeros((3, 5, 7, 11), dtype='float32')
dummy_shape = tensor.lvector()
dummy_one_shape = tensor.ones(4, dtype='int64')
constant_vec_shape = tensor.constant([3, 5, 7, 11])
tuple_shape = (3, 5, 7, 11)
list_shape = list(tuple_shape)
constant_list_shape = [tensor.constant(i, dtype='int64')
for i in tuple_shape]
constant_tuple_shape = tuple(constant_list_shape)
bad_shapes = (
dummy_shape,
dummy_one_shape,
dummy_t4.shape,
alloc_dummy_t4.shape,
constant_vec_shape,
)
good_shapes = (
constant_list_shape,
constant_tuple_shape,
tuple_shape,
list_shape
)
ops_to_test = (
AbstractConv2d,
AbstractConv2d_gradInputs,
AbstractConv2d_gradWeights
)
for op in ops_to_test:
for shp in bad_shapes:
assert_raises(ValueError, op, imshp=shp)
assert_raises(ValueError, op, kshp=shp)
for shp in good_shapes:
op(imshp=shp)
op(kshp=shp)
class TestConvTypes(unittest.TestCase):
def setUp(self):
self.input = tensor.ftensor4()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论