提交 576f6eaf authored 作者: Hengjean's avatar Hengjean

Merged inplace and non inplace append and extend

上级 fc64ccaa
......@@ -91,6 +91,12 @@ class Append(Op):
"""
#append an element at the end of another list
"""
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other)
......@@ -103,17 +109,26 @@ class Append(Op):
return Apply(self, [x, toAppend], [x.type()])
def perform(self, node, (x, toAppend), (out, )):
out[0] = list(x)
if self.inplace:
out[0] = list(x)
else:
out[0] = x
out[0].append(toAppend)
def __str__(self):
return self.__class__.__name__
class ExtendInplace(Op):
class Extend(Op):
"""
append all element of a list at the end of another list
"""
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other)
......@@ -126,19 +141,18 @@ class ExtendInplace(Op):
return Apply(self, [x, toAppend], [x.type()])
def perform(self, node, (x, toAppend), (out, )):
out[0] = x
if self.inplace:
out[0] = list(x)
else:
out[0] = x
out[0].extend(toAppend)
def __str__(self):
return self.__class__.__name__
destroy_map = {0: [0]}
class Insert(Op):
class Extend(Op):
"""
append all element of a list at the end of another list
"""
def __eq__(self, other):
return type(self) == type(other)
......@@ -150,7 +164,7 @@ class Extend(Op):
assert x.type == toAppend.type
return Apply(self, [x, toAppend], [x.type()])
def perform(self, node, (x, toAppend), (out, )):
def perform(self, node, (x, index, toAppend), (out, )):
out[0] = list(x)
out[0].extend(toAppend)
......
......@@ -7,7 +7,7 @@ import theano.typed_list
from theano import tensor as T
from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, AppendInplace, ExtendInplace,
from theano.typed_list.basic import (GetItem,
Append, Extend)
from theano.tests import unittest_tools as utt
......@@ -114,7 +114,7 @@ class test_append(unittest.TestCase):
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
z = AppendInplace()(mySymbolicMatricesList, myMatrix)
z = Append(True)(mySymbolicMatricesList, myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z,
accept_inplace=True)
......@@ -164,7 +164,7 @@ class test_extend(unittest.TestCase):
mySymbolicMatricesList2 = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = ExtendInplace()(mySymbolicMatricesList1, mySymbolicMatricesList2)
z = Extend(True)(mySymbolicMatricesList1, mySymbolicMatricesList2)
f = theano.function([mySymbolicMatricesList1, mySymbolicMatricesList2],
z, accept_inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论