提交 4ba32765 authored 作者: Hengjean's avatar Hengjean

Fix and improvements.

上级 50dd89db
......@@ -97,7 +97,7 @@ class Append(Op):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
return hash(type(self)) ^ self.inplace
def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType)
......@@ -131,7 +131,7 @@ class Extend(Op):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
return hash(type(self)) ^ self.inplace
def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType)
......@@ -159,10 +159,10 @@ class Insert(Op):
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other)
return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self))
return hash(type(self)) ^ self.inplace
def make_node(self, x, index, toInsert):
assert isinstance(x.type, TypedListType)
......@@ -194,10 +194,10 @@ class Remove(Op):
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other)
return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self))
return hash(type(self)) ^ self.inplace
def make_node(self, x, toRemove):
assert isinstance(x.type, TypedListType)
......@@ -238,10 +238,10 @@ class Reverse(Op):
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other)
return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self))
return hash(type(self)) ^ self.inplace
def make_node(self, x):
assert isinstance(x.type, TypedListType)
......@@ -281,7 +281,7 @@ class Index(Op):
array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list
"""
for y in range(x.__len__()):
for y in range(len(x)):
if self.values_eq(x[y], elem):
out[0] = numpy.asarray(y, dtype=theano.config.floatX)
break
......@@ -313,7 +313,7 @@ class Count(Op):
being thrown when trying to remove a matrix from a matrices list
"""
out[0] = 0
for y in range(x.__len__()):
for y in range(len(x)):
if self.values_eq(x[y], elem):
out[0] += 1
out[0] = numpy.asarray(out[0], dtype=theano.config.floatX)
......
from theano import gof
from theano import numpy
class TypedListType(gof.Type):
......@@ -67,14 +67,11 @@ class TypedListType(gof.Type):
return 0
def values_eq(self, a, b):
if not a.__len__() == b.__len__():
if not len(a) == len(b):
return False
equal = True
for x in range(a.__len__()):
for x in range(len(a)):
if not self.ttype.values_eq(a[x], b[x]):
equal = False
break
return False
return equal
return True
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论