提交 5252b4df authored 作者: Hengjean's avatar Hengjean

Added non inplace append and extend

上级 eb1d2f9f
...@@ -62,7 +62,7 @@ class GetItem(Op): ...@@ -62,7 +62,7 @@ class GetItem(Op):
return self.__class__.__name__ return self.__class__.__name__
class Append(Op): class AppendInplace(Op):
""" """
#append an element at the end of another list #append an element at the end of another list
""" """
...@@ -87,9 +87,9 @@ class Append(Op): ...@@ -87,9 +87,9 @@ class Append(Op):
destroy_map = {0: [0]} destroy_map = {0: [0]}
class Extend(Op): class Append(Op):
""" """
append all element of a list at the end of another list #append an element at the end of another list
""" """
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -98,6 +98,29 @@ class Extend(Op): ...@@ -98,6 +98,29 @@ class Extend(Op):
return hash(type(self)) return hash(type(self))
def make_node(self, x, toAppend): def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType)
assert x.ttype == toAppend.type
return Apply(self, [x, toAppend], [x.type()])
def perform(self, node, (x, toAppend), (out, )):
out[0] = list(x)
out[0].append(toAppend)
def __str__(self):
return self.__class__.__name__
class ExtendInplace(Op):
"""
append all element of a list at the end of another list
"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, toAppend, inplace=False):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert x.type == toAppend.type assert x.type == toAppend.type
return Apply(self, [x, toAppend], [x.type()]) return Apply(self, [x, toAppend], [x.type()])
...@@ -110,3 +133,26 @@ class Extend(Op): ...@@ -110,3 +133,26 @@ class Extend(Op):
return self.__class__.__name__ return self.__class__.__name__
destroy_map = {0: [0]} destroy_map = {0: [0]}
class Extend(Op):
"""
append all element of a list at the end of another list
"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, toAppend, inplace=False):
assert isinstance(x.type, TypedListType)
assert x.type == toAppend.type
return Apply(self, [x, toAppend], [x.type()])
def perform(self, node, (x, toAppend), (out, )):
out[0] = list(x)
out[0].extend(toAppend)
def __str__(self):
return self.__class__.__name__
...@@ -7,7 +7,8 @@ import theano.typed_list ...@@ -7,7 +7,8 @@ import theano.typed_list
from theano import tensor as T from theano import tensor as T
from theano.tensor.type_other import SliceType from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Append, Extend) from theano.typed_list.basic import (GetItem, AppendInplace, ExtendInplace,
Append, Extend)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -113,7 +114,7 @@ class test_append(unittest.TestCase): ...@@ -113,7 +114,7 @@ class test_append(unittest.TestCase):
theano.config.floatX, (False, False)))() theano.config.floatX, (False, False)))()
myMatrix = T.matrix() myMatrix = T.matrix()
z = Append()(mySymbolicMatricesList, myMatrix) z = AppendInplace()(mySymbolicMatricesList, myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z, f = theano.function([mySymbolicMatricesList, myMatrix], z,
accept_inplace=True) accept_inplace=True)
...@@ -124,6 +125,21 @@ class test_append(unittest.TestCase): ...@@ -124,6 +125,21 @@ class test_append(unittest.TestCase):
self.assertTrue(numpy.array_equal(f([x], y), [x, y])) self.assertTrue(numpy.array_equal(f([x], y), [x, y]))
def test_sanity_check(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
z = Append()(mySymbolicMatricesList, myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], y), [x, y]))
class test_extend(unittest.TestCase): class test_extend(unittest.TestCase):
...@@ -133,7 +149,7 @@ class test_extend(unittest.TestCase): ...@@ -133,7 +149,7 @@ class test_extend(unittest.TestCase):
mySymbolicMatricesList2 = TypedListType(T.TensorType( mySymbolicMatricesList2 = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))() theano.config.floatX, (False, False)))()
z = Extend()(mySymbolicMatricesList1, mySymbolicMatricesList2) z = ExtendInplace()(mySymbolicMatricesList1, mySymbolicMatricesList2)
f = theano.function([mySymbolicMatricesList1, mySymbolicMatricesList2], f = theano.function([mySymbolicMatricesList1, mySymbolicMatricesList2],
z, accept_inplace=True) z, accept_inplace=True)
...@@ -143,3 +159,20 @@ class test_extend(unittest.TestCase): ...@@ -143,3 +159,20 @@ class test_extend(unittest.TestCase):
y = rand_ranged_matrix(-1000, 1000, [100, 101]) y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], [y]), [x, y])) self.assertTrue(numpy.array_equal(f([x], [y]), [x, y]))
def test_sanity_check(self):
mySymbolicMatricesList1 = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicMatricesList2 = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = Extend()(mySymbolicMatricesList1, mySymbolicMatricesList2)
f = theano.function([mySymbolicMatricesList1, mySymbolicMatricesList2],
z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], [y]), [x, y]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论