提交 7eef6f8d authored 作者: Hengjean's avatar Hengjean

Added Count Op

上级 9e2a2b17
...@@ -27,6 +27,9 @@ class _typed_list_py_operators: ...@@ -27,6 +27,9 @@ class _typed_list_py_operators:
def reverse(self): def reverse(self):
return Reverse()(self) return Reverse()(self)
def count(self, elem):
return Count()(self, elem)
#name "index" is already used by an attribute #name "index" is already used by an attribute
def ind(self, elem): def ind(self, elem):
return Index()(self, elem) return Index()(self, elem)
...@@ -274,3 +277,35 @@ class Index(Op): ...@@ -274,3 +277,35 @@ class Index(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
class Count(Op):
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, elem):
assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type
return Apply(self, [x, elem], [T.scalar()])
def perform(self, node, (x, elem), (out, )):
"""
inelegant workaround for ValueError: The truth value of an
array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list
"""
if isinstance(elem, numpy.ndarray):
out[0] = 0
for y in range(x.__len__()):
if numpy.array_equal(x[y], elem):
out[0] += 1
out[0] = numpy.asarray([out[0]])
else:
out[0] = numpy.asarray([x.count(elem)])
def __str__(self):
return self.__class__.__name__
...@@ -9,7 +9,7 @@ from theano.tensor.type_other import SliceType ...@@ -9,7 +9,7 @@ from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Insert, from theano.typed_list.basic import (GetItem, Insert,
Append, Extend, Remove, Reverse, Append, Extend, Remove, Reverse,
Index) Index, Count)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -396,7 +396,7 @@ class test_index(unittest.TestCase): ...@@ -396,7 +396,7 @@ class test_index(unittest.TestCase):
mySymbolicMatricesList = TypedListType(T.TensorType( mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))() theano.config.floatX, (False, False)))()
z = mySymbolicNestedMatricesList.ind(mySymbolicMatricesList) z = Index()(mySymbolicNestedMatricesList, mySymbolicMatricesList)
f = theano.function([mySymbolicNestedMatricesList, f = theano.function([mySymbolicNestedMatricesList,
mySymbolicMatricesList], z) mySymbolicMatricesList], z)
...@@ -406,3 +406,53 @@ class test_index(unittest.TestCase): ...@@ -406,3 +406,53 @@ class test_index(unittest.TestCase):
y = rand_ranged_matrix(-1000, 1000, [100, 101]) y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(f([[x, y], [x, y, y]], [x, y]) == 0) self.assertTrue(f([[x, y], [x, y, y]], [x, y]) == 0)
class test_count(unittest.TestCase):
def test_sanity_check(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
z = Count()(mySymbolicMatricesList, myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(f([y, y, x, y], y) == 3)
def test_interface(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
myMatrix = T.matrix()
z = mySymbolicMatricesList.count(myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(f([x, y], y) == 1)
def test_non_tensor_type(self):
mySymbolicNestedMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)), 1)()
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = Count()(mySymbolicNestedMatricesList, mySymbolicMatricesList)
f = theano.function([mySymbolicNestedMatricesList,
mySymbolicMatricesList], z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(f([[x, y], [x, y, y]], [x, y]) == 1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论