提交 cd3134be authored 作者: Hengjean's avatar Hengjean

Added remove op

上级 ba9d8c57
...@@ -4,6 +4,8 @@ from theano.gof import Apply, Constant, Op, Variable ...@@ -4,6 +4,8 @@ from theano.gof import Apply, Constant, Op, Variable
from theano.tensor.type_other import SliceType from theano.tensor.type_other import SliceType
from theano import tensor as T from theano import tensor as T
import numpy
class _typed_list_py_operators: class _typed_list_py_operators:
...@@ -146,7 +148,7 @@ class Insert(Op): ...@@ -146,7 +148,7 @@ class Insert(Op):
index = index = T.constant(index, ndim=0) index = index = T.constant(index, ndim=0)
else: else:
assert isinstance(index, T.TensorVariable) and index.ndim == 0 assert isinstance(index, T.TensorVariable) and index.ndim == 0
return Apply(self, [x, index, toInsert], [x.ttype()]) return Apply(self, [x, index, toInsert], [x.type()])
def perform(self, node, (x, index, toInsert), (out, )): def perform(self, node, (x, index, toInsert), (out, )):
if not self.inplace: if not self.inplace:
...@@ -157,3 +159,45 @@ class Insert(Op): ...@@ -157,3 +159,45 @@ class Insert(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
class Remove(Op):
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, toRemove):
assert isinstance(x.type, TypedListType)
assert x.ttype == toRemove.type
return Apply(self, [x, toRemove], [x.type()])
def perform(self, node, (x, toRemove), (out, )):
if not self.inplace:
out[0] = list(x)
else:
out[0] = x
"""
inelegant workaround for ValueError: The truth value of an
array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list
"""
if isinstance(toRemove, numpy.ndarray):
for y in x:
if numpy.array_equal(y, toRemove):
toRemove = y
break
out[0].remove(toRemove)
def __str__(self):
return self.__class__.__name__
...@@ -8,7 +8,7 @@ from theano import tensor as T ...@@ -8,7 +8,7 @@ from theano import tensor as T
from theano.tensor.type_other import SliceType from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Insert, from theano.typed_list.basic import (GetItem, Insert,
Append, Extend) Append, Extend, Remove)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -212,7 +212,6 @@ class test_extend(unittest.TestCase): ...@@ -212,7 +212,6 @@ class test_extend(unittest.TestCase):
class test_insert(unittest.TestCase): class test_insert(unittest.TestCase):
#FAILING ValueError: expected an ndarray
def test_inplace(self): def test_inplace(self):
mySymbolicMatricesList = TypedListType(T.TensorType( mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))() theano.config.floatX, (False, False)))()
...@@ -245,3 +244,37 @@ class test_insert(unittest.TestCase): ...@@ -245,3 +244,37 @@ class test_insert(unittest.TestCase):
y = rand_ranged_matrix(-1000, 1000, [100, 101]) y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1), y), [x, y])) self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1), y), [x, y]))
class test_remove(unittest.TestCase):
def test_inplace(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
z = Remove(True)(mySymbolicMatricesList, myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z,
accept_inplace=True)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x, y], y), [x]))
def test_sanity_check(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
z = Remove()(mySymbolicMatricesList, myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x, y], y), [x]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论