Commit 5c7cf36f, authored by abergeron

Merge pull request #3913 from lamblin/check_constant_shp

Make sure imshp and kshp are constant
import numpy
import theano
from theano.tensor.nnet.tests import test_abstract_conv
from theano.sandbox.cuda import float32_shared_constructor as gpu_shared
......@@ -82,3 +84,5 @@ class TestDnnConvTypes(test_abstract_conv.TestConvTypes):
self.input = cuda.ftensor4()
self.filters = cuda.ftensor4()
self.topgrad = cuda.ftensor4()
self.constant_tensor = cuda.CudaNdarray(
numpy.zeros((3, 5, 7, 11), dtype='float32'))
from nose.plugins.skip import SkipTest
import numpy
from theano.tensor.nnet.tests import test_abstract_conv
from ..type import GpuArrayType, gpuarray_shared_constructor
from ..type import GpuArrayType, gpuarray_shared_constructor, get_context
from ..dnn import dnn_available, GpuDnnConv, GpuDnnConvGradW, GpuDnnConvGradI
from .config import mode_with_gpu, test_ctx_name
from pygpu import gpuarray
gpu_ftensor4 = GpuArrayType(dtype='float32', broadcastable=(False,) * 4)
......@@ -43,3 +46,6 @@ class TestDnnConvTypes(test_abstract_conv.TestConvTypes):
self.input = gpu_ftensor4()
self.filters = gpu_ftensor4()
self.topgrad = gpu_ftensor4()
self.constant_tensor = gpuarray.array(
numpy.zeros((3, 5, 7, 11), dtype='float32'),
context=get_context(test_ctx_name))
......@@ -3,9 +3,13 @@ Abstract conv interface
"""
import logging
from six import reraise
import sys
import theano
from theano.tensor import as_tensor_variable, patternbroadcast
from theano.tensor import get_scalar_constant_value, NotScalarConstantError
from theano.gof import Apply, Op
from six.moves import xrange
......@@ -412,8 +416,30 @@ class BaseAbstractConv2d(Op):
'"valid", "full", "half", an integer or a pair of'
' integers'.format(border_mode))
self.imshp = tuple(imshp) if imshp else None
self.kshp = tuple(kshp) if kshp else None
self.imshp = tuple(imshp) if imshp else (None,) * 4
for imshp_i in self.imshp:
if imshp_i is not None:
# Components of imshp should be constant or ints
try:
get_scalar_constant_value(imshp_i,
only_process_constants=True)
except NotScalarConstantError:
reraise(ValueError,
ValueError("imshp should be None or a tuple of "
"constant int values"),
sys.exc_info()[2])
self.kshp = tuple(kshp) if kshp else (None,) * 4
for kshp_i in self.kshp:
if kshp_i is not None:
# Components of kshp should be constant or ints
try:
get_scalar_constant_value(kshp_i,
only_process_constants=True)
except NotScalarConstantError:
reraise(ValueError,
ValueError("kshp should be None or a tuple of "
"constant int values"),
sys.exc_info()[2])
self.border_mode = border_mode
self.filter_flip = filter_flip
......@@ -489,7 +515,11 @@ class AbstractConv2d(BaseAbstractConv2d):
filter_flip)
def make_node(self, img, kern):
# Make sure both inputs have the same Type
# Make sure both inputs are Variables with the same Type
if not isinstance(img, theano.Variable):
img = as_tensor_variable(img)
if not isinstance(kern, theano.Variable):
kern = as_tensor_variable(kern)
ktype = img.type.clone(dtype=kern.dtype,
broadcastable=kern.broadcastable)
kern = ktype.filter_variable(kern)
......@@ -614,7 +644,11 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
# Update shape/height_width
def make_node(self, img, topgrad, shape):
# Make sure both inputs have the same Type
# Make sure both inputs are Variables with the same Type
if not isinstance(img, theano.Variable):
img = as_tensor_variable(img)
if not isinstance(topgrad, theano.Variable):
topgrad = as_tensor_variable(topgrad)
gtype = img.type.clone(dtype=topgrad.dtype,
broadcastable=topgrad.broadcastable)
topgrad = gtype.filter_variable(topgrad)
......@@ -745,7 +779,11 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
# Update shape/height_width
def make_node(self, kern, topgrad, shape):
# Make sure both inputs have the same Type
# Make sure both inputs are Variables with the same Type
if not isinstance(kern, theano.Variable):
kern = as_tensor_variable(kern)
if not isinstance(topgrad, theano.Variable):
topgrad = as_tensor_variable(topgrad)
gtype = kern.type.clone(dtype=topgrad.dtype,
broadcastable=topgrad.broadcastable)
topgrad = gtype.filter_variable(topgrad)
......
......@@ -2,12 +2,16 @@ import numpy
import unittest
from nose.plugins.skip import SkipTest
from nose.tools import assert_raises
import theano
from theano import tensor
from theano.tests import unittest_tools as utt
from theano.tensor.nnet import corr, abstract_conv as conv
from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.nnet.abstract_conv import AbstractConv2d
from theano.tensor.nnet.abstract_conv import AbstractConv2d_gradInputs
from theano.tensor.nnet.abstract_conv import AbstractConv2d_gradWeights
from theano.tensor.nnet.conv import ConvOp
from theano.tensor.nnet.corr import (CorrMM, CorrMM_gradWeights,
CorrMM_gradInputs)
......@@ -389,12 +393,61 @@ class TestCpuConv2d(BaseTestConv2d):
filter_flip=flip)
def test_constant_shapes():
    """Check that the AbstractConv Ops validate ``imshp`` and ``kshp``.

    Each Op constructor must raise ``ValueError`` when either argument is
    one of the symbolic shapes listed below, and must accept sequences of
    Python ints or of scalar constants.
    """
    int_shape = (3, 5, 7, 11)
    scalar_constants = [tensor.constant(dim, dtype='int64')
                        for dim in int_shape]

    # Shapes the Ops are expected to reject.
    rejected = (
        tensor.lvector(),
        tensor.ones(4, dtype='int64'),
        tensor.ftensor4().shape,
        tensor.zeros((3, 5, 7, 11), dtype='float32').shape,
        tensor.constant([3, 5, 7, 11]),
    )

    # Shapes the Ops are expected to accept.
    accepted = (
        scalar_constants,
        tuple(scalar_constants),
        int_shape,
        list(int_shape),
    )

    op_classes = (
        AbstractConv2d,
        AbstractConv2d_gradInputs,
        AbstractConv2d_gradWeights,
    )
    shape_kwargs = ('imshp', 'kshp')

    for op_cls in op_classes:
        for shp in rejected:
            for kwarg in shape_kwargs:
                assert_raises(ValueError, op_cls, **{kwarg: shp})

        for shp in accepted:
            for kwarg in shape_kwargs:
                # Construction alone must succeed; the Op is not applied.
                op_cls(**{kwarg: shp})
class TestConvTypes(unittest.TestCase):
def setUp(self):
    """Build the symbolic inputs and the constant array used by the tests."""
    new_ftensor4 = tensor.ftensor4
    self.input = new_ftensor4()
    self.filters = new_ftensor4()
    self.topgrad = new_ftensor4()
    # Non-symbolic value passed in place of a Variable by the tests.
    self.constant_tensor = numpy.zeros(shape=(3, 5, 7, 11), dtype='float32')
def test_grad_types(self):
# This function simply tests the behaviour of the AbstractConv
# Ops, not their optimizations
......@@ -431,3 +484,48 @@ class TestConvTypes(unittest.TestCase):
grad_filters, grad_filters.type, filters, filters.type)
assert grad_topgrad.type == topgrad.type, (
grad_topgrad, grad_topgrad.type, topgrad, topgrad.type)
def test_constant_input(self):
    """Check the AbstractConv Ops when one of their inputs is a constant.

    For the forward Op and both gradient Ops, ``self.constant_tensor`` is
    passed in place of each input in turn. The gradient with respect to
    the remaining symbolic input must have the same Type as that input.
    """
    inp = self.input
    filters = self.filters
    topgrad = self.topgrad
    const = self.constant_tensor
    out_shape = tensor.lvector()

    def check_grad_type(expr, wrt):
        # The gradient of expr.sum() w.r.t. `wrt` must keep the Type of `wrt`.
        grad = theano.grad(expr.sum(), wrt=wrt)
        assert grad.type == wrt.type, (grad, grad.type, wrt, wrt.type)

    # Forward Op.
    check_grad_type(conv.conv2d(const, filters), filters)
    check_grad_type(conv.conv2d(inp, const), inp)

    # Gradient w.r.t. the weights.
    check_grad_type(
        conv.AbstractConv2d_gradWeights()(const, topgrad, out_shape),
        topgrad)
    check_grad_type(
        conv.AbstractConv2d_gradWeights()(inp, const, out_shape),
        inp)

    # Gradient w.r.t. the inputs.
    check_grad_type(
        conv.AbstractConv2d_gradInputs()(const, topgrad, out_shape),
        topgrad)
    check_grad_type(
        conv.AbstractConv2d_gradInputs()(filters, const, out_shape),
        filters)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论