提交 0a876e23 authored 作者: mronian's avatar mronian

Adds test for uint datatype for image_shape and filter_shape in…

Adds test for uint datatype for image_shape and filter_shape in theano.tensor.nnet.conv and also checks if datatype is in tensor.discrete_dtypes
上级 22db3930
......@@ -108,7 +108,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
" information are constant values. We got"
" %s for the image_shape parameter" %
image_shape[i])
assert "int" in str(image_shape[i].dtype)
assert image_shape[i].dtype in theano.tensor.discrete_dtypes
image_shape[i] = int(image_shape[i])
if filter_shape is not None:
filter_shape = list(filter_shape)
......@@ -123,7 +123,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
" information are constant values. We got"
" %s for the filter_shape "
"parameter" % filter_shape[i])
assert "int" in str(filter_shape[i].dtype)
assert filter_shape[i].dtype in theano.tensor.discrete_dtypes
filter_shape[i] = int(filter_shape[i])
if image_shape and filter_shape:
......
......@@ -3,7 +3,6 @@ import time
from nose.plugins.skip import SkipTest
import numpy
import theano
import theano.tensor as T
from theano.tests import unittest_tools as utt
......@@ -154,6 +153,22 @@ class TestConv2D(utt.InferShapeTester):
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full')
# test filter same size as input
def test_uint_image_shape_datatype(self):
"""Tests for uint datatype in image_shape.
"""
self.validate((2, 2, 3, numpy.uint8(3)), (3, 2, 3, 3), 'valid', verify_grad=False)
self.validate((numpy.uint16(2), 2, 3, 3), (3, 2, 3, 3), 'valid', verify_grad=False)
self.validate((2, numpy.uint32(2), 3, 3), (3, 2, 3, 3), 'valid', verify_grad=False)
def test_uint_filter_shape_datatype(self):
"""Tests for uint datatype in filter_shape
"""
self.validate((3, 2, 3, 3), (2, 2, 3, numpy.uint8(3)), 'valid', verify_grad=False)
self.validate((3, 2, 3, 3), (numpy.uint16(2), 2, 3, 3), 'valid', verify_grad=False)
self.validate((3, 2, 3, 3), (2, numpy.uint32(2), 3, 3), 'valid', verify_grad=False)
def test_img_kernel_same_shape(self):
self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'full')
self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'valid')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论