提交 4ba32765 authored 作者: Hengjean's avatar Hengjean

Fix and improvements.

上级 50dd89db
...@@ -97,7 +97,7 @@ class Append(Op): ...@@ -97,7 +97,7 @@ class Append(Op):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self)) ^ self.inplace
def make_node(self, x, toAppend): def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
...@@ -131,7 +131,7 @@ class Extend(Op): ...@@ -131,7 +131,7 @@ class Extend(Op):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self)) ^ self.inplace
def make_node(self, x, toAppend): def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
...@@ -159,10 +159,10 @@ class Insert(Op): ...@@ -159,10 +159,10 @@ class Insert(Op):
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self)) ^ self.inplace
def make_node(self, x, index, toInsert): def make_node(self, x, index, toInsert):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
...@@ -194,10 +194,10 @@ class Remove(Op): ...@@ -194,10 +194,10 @@ class Remove(Op):
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self)) ^ self.inplace
def make_node(self, x, toRemove): def make_node(self, x, toRemove):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
...@@ -238,10 +238,10 @@ class Reverse(Op): ...@@ -238,10 +238,10 @@ class Reverse(Op):
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self)) ^ self.inplace
def make_node(self, x): def make_node(self, x):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
...@@ -281,7 +281,7 @@ class Index(Op): ...@@ -281,7 +281,7 @@ class Index(Op):
array with more than one element is ambiguous. Use a.any() or a.all() array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list being thrown when trying to remove a matrix from a matrices list
""" """
for y in range(x.__len__()): for y in range(len(x)):
if self.values_eq(x[y], elem): if self.values_eq(x[y], elem):
out[0] = numpy.asarray(y, dtype=theano.config.floatX) out[0] = numpy.asarray(y, dtype=theano.config.floatX)
break break
...@@ -313,7 +313,7 @@ class Count(Op): ...@@ -313,7 +313,7 @@ class Count(Op):
being thrown when trying to remove a matrix from a matrices list being thrown when trying to remove a matrix from a matrices list
""" """
out[0] = 0 out[0] = 0
for y in range(x.__len__()): for y in range(len(x)):
if self.values_eq(x[y], elem): if self.values_eq(x[y], elem):
out[0] += 1 out[0] += 1
out[0] = numpy.asarray(out[0], dtype=theano.config.floatX) out[0] = numpy.asarray(out[0], dtype=theano.config.floatX)
......
from theano import gof from theano import gof
from theano import numpy
class TypedListType(gof.Type): class TypedListType(gof.Type):
...@@ -67,14 +67,11 @@ class TypedListType(gof.Type): ...@@ -67,14 +67,11 @@ class TypedListType(gof.Type):
return 0 return 0
def values_eq(self, a, b): def values_eq(self, a, b):
if not a.__len__() == b.__len__(): if not len(a) == len(b):
return False return False
equal = True for x in range(len(a)):
for x in range(a.__len__()):
if not self.ttype.values_eq(a[x], b[x]): if not self.ttype.values_eq(a[x], b[x]):
equal = False return False
break
return equal return True
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论