提交 5617dc91 authored 作者: Hengjean's avatar Hengjean

Fix and improvement.

上级 c2811e8e
...@@ -109,7 +109,7 @@ class Append(Op): ...@@ -109,7 +109,7 @@ class Append(Op):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ self.inplace return hash(type(self)) ^ hash(self.inplace)
def make_node(self, x, toAppend): def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
...@@ -139,6 +139,9 @@ class Append(Op): ...@@ -139,6 +139,9 @@ class Append(Op):
%(output_name)s = %(x_name)s; %(output_name)s = %(x_name)s;
""" % locals() """ % locals()
return init + """ return init + """
if(%(output_name)s==NULL){
%(fail)s
};
if(PyList_Append( (PyObject*) %(output_name)s,(PyObject*) %(toAppend)s)){ if(PyList_Append( (PyObject*) %(output_name)s,(PyObject*) %(toAppend)s)){
%(fail)s %(fail)s
}; };
...@@ -162,7 +165,7 @@ class Extend(Op): ...@@ -162,7 +165,7 @@ class Extend(Op):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ self.inplace return hash(type(self)) ^ hash(self.inplace)
def make_node(self, x, toAppend): def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
...@@ -219,7 +222,7 @@ class Insert(Op): ...@@ -219,7 +222,7 @@ class Insert(Op):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ self.inplace return hash(type(self)) ^ hash(self.inplace)
def make_node(self, x, index, toInsert): def make_node(self, x, index, toInsert):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
...@@ -253,6 +256,9 @@ class Insert(Op): ...@@ -253,6 +256,9 @@ class Insert(Op):
%(output_name)s = %(x_name)s; %(output_name)s = %(x_name)s;
""" % locals() """ % locals()
return init + """ return init + """
if(%(output_name)s==NULL){
%(fail)s
};
if(PyList_Insert((PyObject*) %(output_name)s, *((double *) PyArray_DATA(%(index)s)), (PyObject*) %(toInsert)s)==-1){ if(PyList_Insert((PyObject*) %(output_name)s, *((double *) PyArray_DATA(%(index)s)), (PyObject*) %(toInsert)s)==-1){
%(fail)s %(fail)s
}; };
...@@ -273,7 +279,7 @@ class Remove(Op): ...@@ -273,7 +279,7 @@ class Remove(Op):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ self.inplace return hash(type(self)) ^ hash(self.inplace)
def make_node(self, x, toRemove): def make_node(self, x, toRemove):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
...@@ -292,13 +298,10 @@ class Remove(Op): ...@@ -292,13 +298,10 @@ class Remove(Op):
array with more than one element is ambiguous. Use a.any() or a.all() array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list being thrown when trying to remove a matrix from a matrices list
""" """
if isinstance(toRemove, numpy.ndarray): for y in range(out[0].__len__()):
for y in range(out[0].__len__()): if node.inputs[0].ttype.values_eq(out[0][y], toRemove):
if numpy.array_equal(out[0][y], toRemove):
del out[0][y] del out[0][y]
break break
else:
out[0].remove(toRemove)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
...@@ -317,7 +320,7 @@ class Reverse(Op): ...@@ -317,7 +320,7 @@ class Reverse(Op):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ self.inplace return hash(type(self)) ^ hash(self.inplace)
def make_node(self, x): def make_node(self, x):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
...@@ -347,6 +350,9 @@ class Reverse(Op): ...@@ -347,6 +350,9 @@ class Reverse(Op):
%(output_name)s = %(x_name)s; %(output_name)s = %(x_name)s;
""" % locals() """ % locals()
return init + """ return init + """
if(%(output_name)s==NULL){
%(fail)s
};
if(PyList_Reverse((PyObject*) %(output_name)s)==-1){ if(PyList_Reverse((PyObject*) %(output_name)s)==-1){
%(fail)s %(fail)s
}; };
...@@ -367,7 +373,6 @@ class Index(Op): ...@@ -367,7 +373,6 @@ class Index(Op):
def make_node(self, x, elem): def make_node(self, x, elem):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type assert x.ttype == elem.type
self.values_eq = x.ttype.values_eq
return Apply(self, [x, elem], [T.scalar()]) return Apply(self, [x, elem], [T.scalar()])
def perform(self, node, (x, elem), (out, )): def perform(self, node, (x, elem), (out, )):
...@@ -377,7 +382,7 @@ class Index(Op): ...@@ -377,7 +382,7 @@ class Index(Op):
being thrown when trying to remove a matrix from a matrices list being thrown when trying to remove a matrix from a matrices list
""" """
for y in range(len(x)): for y in range(len(x)):
if self.values_eq(x[y], elem): if node.inputs[0].ttype.values_eq(x[y], elem):
out[0] = numpy.asarray(y, dtype=theano.config.floatX) out[0] = numpy.asarray(y, dtype=theano.config.floatX)
break break
...@@ -398,7 +403,6 @@ class Count(Op): ...@@ -398,7 +403,6 @@ class Count(Op):
def make_node(self, x, elem): def make_node(self, x, elem):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type assert x.ttype == elem.type
self.values_eq = x.ttype.values_eq
return Apply(self, [x, elem], [T.scalar()]) return Apply(self, [x, elem], [T.scalar()])
def perform(self, node, (x, elem), (out, )): def perform(self, node, (x, elem), (out, )):
...@@ -409,7 +413,7 @@ class Count(Op): ...@@ -409,7 +413,7 @@ class Count(Op):
""" """
out[0] = 0 out[0] = 0
for y in range(len(x)): for y in range(len(x)):
if self.values_eq(x[y], elem): if node.inputs[0].ttype.values_eq(x[y], elem):
out[0] += 1 out[0] += 1
out[0] = numpy.asarray(out[0], dtype=theano.config.floatX) out[0] = numpy.asarray(out[0], dtype=theano.config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论