提交 d82eb54a authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1917 from Hengjean/TypedListLength

Added Length Op to typed list.
...@@ -13,6 +13,9 @@ class _typed_list_py_operators: ...@@ -13,6 +13,9 @@ class _typed_list_py_operators:
def __getitem__(self, index): def __getitem__(self, index):
return getitem(self, index) return getitem(self, index)
def __len__(self):
return length(self)
def append(self, toAppend): def append(self, toAppend):
return append(self, toAppend) return append(self, toAppend)
...@@ -438,3 +441,44 @@ class Count(Op): ...@@ -438,3 +441,44 @@ class Count(Op):
return self.__class__.__name__ return self.__class__.__name__
count = Count() count = Count()
class Length(Op):
# See doc in instance of this Op after the class definition.
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x):
assert isinstance(x.type, TypedListType)
return Apply(self, [x], [T.scalar(dtype='int64')])
def perform(self, node, x, (out, )):
out[0] = numpy.asarray(len(x[0]), 'int64')
def __str__(self):
return self.__class__.__name__
def c_code(self, node, name, inp, out, sub):
x_name = inp[0]
output_name = out[0]
fail = sub['fail']
return """
if(!%(output_name)s)
%(output_name)s=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA(%(output_name)s))[0]=PyList_Size((PyObject*)%(x_name)s);
Py_INCREF(%(output_name)s);
""" % locals()
def c_code_cache_version(self):
return (1,)
length = Length()
"""
Returns the size of a list.
:param x: typed list.
"""
...@@ -9,7 +9,7 @@ from theano.tensor.type_other import SliceType ...@@ -9,7 +9,7 @@ from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Insert, from theano.typed_list.basic import (GetItem, Insert,
Append, Extend, Remove, Reverse, Append, Extend, Remove, Reverse,
Index, Count) Index, Count, Length)
from theano import sparse from theano import sparse
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
import scipy.sparse as sp import scipy.sparse as sp
...@@ -511,3 +511,29 @@ class test_count(unittest.TestCase): ...@@ -511,3 +511,29 @@ class test_count(unittest.TestCase):
y = sp.csr_matrix(random_lil((10, 40), theano.config.floatX, 3)) y = sp.csr_matrix(random_lil((10, 40), theano.config.floatX, 3))
self.assertTrue(f([x, y, y], y) == 2) self.assertTrue(f([x, y, y], y) == 2)
class test_length(unittest.TestCase):
def test_sanity_check(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = Length()(mySymbolicMatricesList)
f = theano.function([mySymbolicMatricesList], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(f([x, x, x, x]) == 4)
def test_interface(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = mySymbolicMatricesList.__len__()
f = theano.function([mySymbolicMatricesList], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(f([x, x]) == 2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论