提交 9e2a2b17 authored 作者: Hengjean's avatar Hengjean

Added index op.

上级 fc40a93c
...@@ -27,6 +27,10 @@ class _typed_list_py_operators: ...@@ -27,6 +27,10 @@ class _typed_list_py_operators:
def reverse(self): def reverse(self):
return Reverse()(self) return Reverse()(self)
#name "index" is already used by an attribute
def ind(self, elem):
return Index()(self, elem)
ttype = property(lambda self: self.type.ttype) ttype = property(lambda self: self.type.ttype)
...@@ -239,3 +243,34 @@ class Reverse(Op): ...@@ -239,3 +243,34 @@ class Reverse(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
class Index(Op):
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, elem):
assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type
return Apply(self, [x, elem], [T.scalar()])
def perform(self, node, (x, elem), (out, )):
"""
inelegant workaround for ValueError: The truth value of an
array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list
"""
if isinstance(elem, numpy.ndarray):
for y in range(x.__len__()):
if numpy.array_equal(x[y], elem):
out[0] = numpy.asarray([y])
break
else:
out[0] = numpy.asarray([x.index(elem)])
def __str__(self):
return self.__class__.__name__
...@@ -8,7 +8,8 @@ from theano import tensor as T ...@@ -8,7 +8,8 @@ from theano import tensor as T
from theano.tensor.type_other import SliceType from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Insert, from theano.typed_list.basic import (GetItem, Insert,
Append, Extend, Remove, Reverse) Append, Extend, Remove, Reverse,
Index)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -355,3 +356,53 @@ class test_reverse(unittest.TestCase): ...@@ -355,3 +356,53 @@ class test_reverse(unittest.TestCase):
y = rand_ranged_matrix(-1000, 1000, [100, 101]) y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x, y]), [y, x])) self.assertTrue(numpy.array_equal(f([x, y]), [y, x]))
class test_index(unittest.TestCase):
def test_sanity_check(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
z = Index()(mySymbolicMatricesList, myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(f([x, y], y) == 1)
def test_interface(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
z = mySymbolicMatricesList.ind(myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(f([x, y], y) == 1)
def test_non_tensor_type(self):
mySymbolicNestedMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)), 1)()
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = mySymbolicNestedMatricesList.ind(mySymbolicMatricesList)
f = theano.function([mySymbolicNestedMatricesList,
mySymbolicMatricesList], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(f([[x, y], [x, y, y]], [x, y]) == 0)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论