提交 48a057db authored 作者: Hengjean's avatar Hengjean

Fixed bug.

上级 e19bfe0d
...@@ -260,6 +260,7 @@ class Index(Op): ...@@ -260,6 +260,7 @@ class Index(Op):
def make_node(self, x, elem): def make_node(self, x, elem):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type assert x.ttype == elem.type
self.values_eq = x.ttype.values_eq
return Apply(self, [x, elem], [T.scalar()]) return Apply(self, [x, elem], [T.scalar()])
def perform(self, node, (x, elem), (out, )): def perform(self, node, (x, elem), (out, )):
...@@ -268,13 +269,10 @@ class Index(Op): ...@@ -268,13 +269,10 @@ class Index(Op):
array with more than one element is ambiguous. Use a.any() or a.all() array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list being thrown when trying to remove a matrix from a matrices list
""" """
if isinstance(elem, numpy.ndarray): for y in range(x.__len__()):
for y in range(x.__len__()): if self.values_eq(x[y], elem):
if numpy.array_equal(x[y], elem): out[0] = numpy.asarray(y, dtype=theano.config.floatX)
out[0] = numpy.asarray([y], dtype=theano.config.floatX) break
break
else:
out[0] = numpy.asarray([x.index(elem)], dtype=theano.config.floatX)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
...@@ -291,6 +289,7 @@ class Count(Op): ...@@ -291,6 +289,7 @@ class Count(Op):
def make_node(self, x, elem): def make_node(self, x, elem):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type assert x.ttype == elem.type
self.values_eq = x.ttype.values_eq
return Apply(self, [x, elem], [T.scalar()]) return Apply(self, [x, elem], [T.scalar()])
def perform(self, node, (x, elem), (out, )): def perform(self, node, (x, elem), (out, )):
...@@ -299,14 +298,11 @@ class Count(Op): ...@@ -299,14 +298,11 @@ class Count(Op):
array with more than one element is ambiguous. Use a.any() or a.all() array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list being thrown when trying to remove a matrix from a matrices list
""" """
if isinstance(elem, numpy.ndarray): out[0] = 0
out[0] = 0 for y in range(x.__len__()):
for y in range(x.__len__()): if self.values_eq(x[y], elem):
if numpy.array_equal(x[y], elem): out[0] += 1
out[0] += 1 out[0] = numpy.asarray(out[0], dtype=theano.config.floatX)
out[0] = numpy.asarray([out[0]], dtype=theano.config.floatX)
else:
out[0] = numpy.asarray([x.count(elem)], dtype=theano.config.floatX)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
...@@ -67,7 +67,14 @@ class TypedListType(gof.Type): ...@@ -67,7 +67,14 @@ class TypedListType(gof.Type):
return 0 return 0
def values_eq(self, a, b): def values_eq(self, a, b):
if isinstance(a, numpy.ndarray): if not a.__len__() == b.__len__():
return numpy.array_equal(a, b) return False
else:
return a == b equal = True
for x in range(a.__len__()):
if not self.ttype.values_eq(a[x], b[x]):
equal = False
break
return equal
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论