提交 6c853649 authored 作者: Hengjean's avatar Hengjean

Implemented basic typed list

上级 47a812ab
from type import TypedListType
import unittest
from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr
import numpy
from numpy.testing import dec, assert_array_equal, assert_allclose
from numpy.testing.noseclasses import KnownFailureTest
import theano
import theano.typed_list
from theano import tensor as T
from theano.typed_list.type import TypedListType
from theano.tests import unittest_tools as utt
#took from tensors/tests/test_basic.py
def rand_ranged_matrix(minimum, maximum, shape):
return numpy.asarray(numpy.random.rand(*shape) * (maximum - minimum)
+ minimum, dtype=theano.config.floatX)
class test_typed_list_type(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_wrong_input_on_creation(self):
"""
Typed list type should raises an
error if the argument passed for
type is not a valid theano type
"""
self.assertRaises(TypeError, TypedListType, None)
def test_wrong_input_on_filter(self):
"""
Typed list type should raises an
error if the argument given to filter
isn't of the same type as the one
specified on creation
"""
#list of matrices
myType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
self.assertRaises(TypeError, myType.filter, [4])
def test_not_a_list_on_filter(self):
"""
Typed List Value should raises an error
if no iterable variable is given on input
"""
#list of matrices
myType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
self.assertRaises(TypeError, myType.filter, 4)
def test_type_equality(self):
"""
Typed list types should only be equal
when they contains the same theano
variables
"""
#list of matrices
myType1 = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
#list of matrices
myType2 = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
#list of scalars
myType3 = TypedListType(T.TensorType(theano.config.floatX,
()))
self.assertTrue(myType2 == myType1)
self.assertFalse(myType3 == myType1)
def test_filter_sanity_check(self):
"""
Simple test on typed list type filter
"""
myType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
x = rand_ranged_matrix(-1000, 1000, [100, 100])
self.assertTrue(numpy.array_equal(myType.filter([x]), [x]))
def test_intern_filter(self):
"""
(supposing theano.config.floatX = floatX)
Test checking if values contained are themselves
filtered. If they weren't this code would raise
an exception.
"""
myType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
x = numpy.asarray([[4, 5], [4, 5]], dtype='Float32')
self.assertTrue(numpy.array_equal(myType.filter([x]), [x]))
#Will fail for unknown reasons
#under search
"""
def test_load(self):
myType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
x = rand_ranged_matrix(-1000, 1000, [100, 100])
testList = []
for i in range(10000):
testList.append(x)
self.assertTrue(numpy.array_equal(myType.filter(testList), testList))
"""
def test_basic_nested_list(self):
"""
Testing nested list with one level of depth
"""
myNestedType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
myType = TypedListType(myNestedType)
x = rand_ranged_matrix(-1000, 1000, [100, 100])
self.assertTrue(numpy.array_equal(myType.filter([[x]]), [[x]]))
def test_comparison_different_depth(self):
"""
Nested list with different depth aren't the same
"""
myNestedType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
myNestedType2 = TypedListType(myNestedType)
myNestedType3 = TypedListType(myNestedType2)
self.assertFalse(myNestedType2 == myNestedType3)
def test_nested_list_arg(self):
"""
test for the 'depth' optionnal argument
"""
myNestedType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)), 3)
myType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
myManualNestedType = TypedListType(TypedListType(
TypedListType(myType)))
self.assertTrue(myNestedType == myManualNestedType)
def test_get_depth(self):
"""
test case for get_depth utilitary function
"""
myType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
myManualNestedType = TypedListType(TypedListType(
TypedListType(myType)))
self.assertTrue(myManualNestedType.get_depth() == 3)
def test_comparison_uneven_nested(self):
"""
test for comparison between uneven nested list
"""
myType = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))
myManualNestedType1 = TypedListType(TypedListType(
TypedListType(myType)))
myManualNestedType2 = TypedListType(TypedListType(
myType))
self.assertFalse(myManualNestedType1 == myManualNestedType2)
self.assertFalse(myManualNestedType2 == myManualNestedType1)
\ No newline at end of file
from theano import gof
class TypedListType(gof.Type):
def __init__(self, ttype, depth=0):
"""
:Parameters:
-'ttype' : Type of theano variable this list
will contains, can be another list.
-'depth' : Optionnal parameters, any value
above 0 will create a nested list of this
depth. (0-based)
"""
if depth < 0:
raise ValueError('Please specify a depth superior or'
'equal to 0')
if not hasattr(ttype, 'is_valid_value'):
raise TypeError('Expected a Theano type')
if depth == 0:
self.ttype = ttype
else:
self.ttype = TypedListType(ttype, depth - 1)
def filter(self, x, strict=False, allow_downcast=None):
"""
:Parameters:
-'x' : value to filter
-'strict' : if true, only native python list will be accepted
-'allow_downcast' : does not have any utility at the moment
"""
if strict:
if not isinstance(x, list):
raise TypeError('Expected a python list')
else:
x = list(x)
#check all member of list are of the same type
#for the moment only one dimension list accepted
x = [self.ttype.filter(y) for y in x]
if all(self.ttype.is_valid_value(y) for y in x):
return x
else:
raise TypeError('Expected all elements to'
' be %s' % str(self.ttype))
def __eq__(self, other):
"""
two list are equals if they contains the same type.
"""
if not hasattr(other, 'ttype'):
return False
return (self.ttype == other.ttype)
def __str__(self):
return 'Typed List <' + str(self.ttype) + '>'
def get_depth(self):
"""
utilitary function to get the 0 based
level of the list
"""
if hasattr(self.ttype, 'get_depth'):
return self.ttype.get_depth() + 1
else:
return 0
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论