提交 687de47f authored 作者: Hengjean's avatar Hengjean

Added reverse op.

上级 3810c977
......@@ -201,3 +201,32 @@ class Remove(Op):
def __str__(self):
return self.__class__.__name__
class Reverse(Op):
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x):
assert isinstance(x.type, TypedListType)
return Apply(self, [x], [x.type()])
def perform(self, node, inp, (out, )):
if not self.inplace:
out[0] = list(inp[0])
else:
out[0] = inp[0]
out[0].reverse()
def __str__(self):
return self.__class__.__name__
......@@ -8,7 +8,7 @@ from theano import tensor as T
from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Insert,
Append, Extend, Remove)
Append, Extend, Remove, Reverse)
from theano.tests import unittest_tools as utt
......@@ -278,3 +278,35 @@ class test_remove(unittest.TestCase):
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x, y], y), [x]))
class test_reverse(unittest.TestCase):
def test_inplace(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = Reverse(True)(mySymbolicMatricesList)
f = theano.function([mySymbolicMatricesList], z,
accept_inplace=True)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x, y]), [y, x]))
def test_sanity_check(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = Reverse()(mySymbolicMatricesList)
f = theano.function([mySymbolicMatricesList], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x, y]), [y, x]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论