提交 b40baa32 authored 作者: Hengjean's avatar Hengjean

Added append and extend.

上级 200db67b
......@@ -10,6 +10,12 @@ class _typed_list_py_operators:
def __getitem__(self, index):
return get_item()(self, index)
def append(self, toAppend):
return append()(self, toAppend)
def extend(self, toAppend):
return extend()(self, toAppend)
class TypedListVariable(_typed_list_py_operators, Variable):
"""
......@@ -30,6 +36,7 @@ class get_item(Op):
return hash(type(self))
def make_node(self, x, index):
assert isinstance(x.type, TypedListType)
if index.type == SliceType():
return Apply(self, [x, index], [x.type()])
elif isinstance(index, T.TensorVariable) and index.ndim == 0:
......@@ -44,3 +51,49 @@ class get_item(Op):
def __str__(self):
return self.__class__.__name__
class append(Op):
"""
#append an element at the end of another list
"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType)
assert x.ttype == toAppend.type
return Apply(self, [x, toAppend], [x.type()])
def perform(self, node, (x, toAppend), (out, )):
out[0] = x
out[0].append(toAppend)
def __str__(self):
return self.__class__.__name__
class extend(Op):
"""
append all element of a list at the end of another list
"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType)
assert x.type == toAppend.type
return Apply(self, [x, toAppend], [x.type()])
def perform(self, node, (x, toAppend), (out, )):
out[0] = x
out[0].extend(toAppend)
def __str__(self):
return self.__class__.__name__
......@@ -11,7 +11,7 @@ import theano.typed_list
from theano import tensor as T
from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType
from theano.typed_list.basic import get_item
from theano.typed_list.basic import (get_item, append, extend)
from theano.tests import unittest_tools as utt
......@@ -21,7 +21,7 @@ def rand_ranged_matrix(minimum, maximum, shape):
+ minimum, dtype=theano.config.floatX)
class test_get_slice(unittest.TestCase):
class test_get_item(unittest.TestCase):
def setUp(self):
utt.seed_rng()
......@@ -81,3 +81,41 @@ class test_get_slice(unittest.TestCase):
self.assertRaises(TypeError, get_item(), mySymbolicMatricesList,
mySymbolicMatrix)
class test_append(unittest.TestCase):
def test_sanity_check(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
z = append()(mySymbolicMatricesList, myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z)
x = rand_ranged_matrix(-1000, 1000, [100, 100])
y = rand_ranged_matrix(-1000, 1000, [100, 100])
self.assertTrue(numpy.array_equal(f([x], y), [x, y]))
class test_extend(unittest.TestCase):
def test_sanity_check(self):
mySymbolicMatricesList1 = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicMatricesList2 = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = extend()(mySymbolicMatricesList1, mySymbolicMatricesList2)
f = theano.function([mySymbolicMatricesList1, mySymbolicMatricesList2],
z)
x = rand_ranged_matrix(-1000, 1000, [100, 100])
y = rand_ranged_matrix(-1000, 1000, [100, 100])
self.assertTrue(numpy.array_equal(f([x], [y]), [x, y]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论