提交 ba9d8c57 authored 作者: Hengjean's avatar Hengjean

Added Insert Op

The inplace test for this fails for unknown reasons
上级 576f6eaf
......@@ -62,31 +62,6 @@ class GetItem(Op):
return self.__class__.__name__
class AppendInplace(Op):
"""
#append an element at the end of another list
"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType)
assert x.ttype == toAppend.type
return Apply(self, [x, toAppend], [x.type()])
def perform(self, node, (x, toAppend), (out, )):
out[0] = x
out[0].append(toAppend)
def __str__(self):
return self.__class__.__name__
destroy_map = {0: [0]}
class Append(Op):
"""
#append an element at the end of another list
......@@ -109,7 +84,7 @@ class Append(Op):
return Apply(self, [x, toAppend], [x.type()])
def perform(self, node, (x, toAppend), (out, )):
if self.inplace:
if not self.inplace:
out[0] = list(x)
else:
out[0] = x
......@@ -141,7 +116,7 @@ class Extend(Op):
return Apply(self, [x, toAppend], [x.type()])
def perform(self, node, (x, toAppend), (out, )):
if self.inplace:
if not self.inplace:
out[0] = list(x)
else:
out[0] = x
......@@ -153,20 +128,32 @@ class Extend(Op):
class Insert(Op):
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, toAppend):
def make_node(self, x, index, toInsert):
assert isinstance(x.type, TypedListType)
assert x.type == toAppend.type
return Apply(self, [x, toAppend], [x.type()])
assert x.ttype == toInsert.type
if not isinstance(index, Variable):
index = index = T.constant(index, ndim=0)
else:
assert isinstance(index, T.TensorVariable) and index.ndim == 0
return Apply(self, [x, index, toInsert], [x.ttype()])
def perform(self, node, (x, index, toAppend), (out, )):
out[0] = list(x)
out[0].extend(toAppend)
def perform(self, node, (x, index, toInsert), (out, )):
if not self.inplace:
out[0] = list(x)
else:
out[0] = x
out[0].insert(index, toInsert)
def __str__(self):
return self.__class__.__name__
......@@ -7,7 +7,7 @@ import theano.typed_list
from theano import tensor as T
from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem,
from theano.typed_list.basic import (GetItem, Insert,
Append, Extend)
from theano.tests import unittest_tools as utt
......@@ -208,3 +208,40 @@ class test_extend(unittest.TestCase):
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], [y]), [x, y]))
class test_insert(unittest.TestCase):
#FAILING ValueError: expected an ndarray
def test_inplace(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
myScalar = T.scalar()
z = Insert(True)(mySymbolicMatricesList, myScalar, myMatrix)
f = theano.function([mySymbolicMatricesList, myScalar, myMatrix], z,
accept_inplace=True)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1), y), [x, y]))
def test_sanity_check(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
myScalar = T.scalar()
z = Insert()(mySymbolicMatricesList, myScalar, myMatrix)
f = theano.function([mySymbolicMatricesList, myScalar, myMatrix], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1), y), [x, y]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论