提交 5617dc91 authored 作者: Hengjean's avatar Hengjean

Fix and improvement.

上级 c2811e8e
......@@ -109,7 +109,7 @@ class Append(Op):
return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self)) ^ self.inplace
return hash(type(self)) ^ hash(self.inplace)
def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType)
......@@ -139,6 +139,9 @@ class Append(Op):
%(output_name)s = %(x_name)s;
""" % locals()
return init + """
if(%(output_name)s==NULL){
%(fail)s
};
if(PyList_Append( (PyObject*) %(output_name)s,(PyObject*) %(toAppend)s)){
%(fail)s
};
......@@ -162,7 +165,7 @@ class Extend(Op):
return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self)) ^ self.inplace
return hash(type(self)) ^ hash(self.inplace)
def make_node(self, x, toAppend):
assert isinstance(x.type, TypedListType)
......@@ -219,7 +222,7 @@ class Insert(Op):
return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self)) ^ self.inplace
return hash(type(self)) ^ hash(self.inplace)
def make_node(self, x, index, toInsert):
assert isinstance(x.type, TypedListType)
......@@ -253,6 +256,9 @@ class Insert(Op):
%(output_name)s = %(x_name)s;
""" % locals()
return init + """
if(%(output_name)s==NULL){
%(fail)s
};
if(PyList_Insert((PyObject*) %(output_name)s, *((double *) PyArray_DATA(%(index)s)), (PyObject*) %(toInsert)s)==-1){
%(fail)s
};
......@@ -273,7 +279,7 @@ class Remove(Op):
return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self)) ^ self.inplace
return hash(type(self)) ^ hash(self.inplace)
def make_node(self, x, toRemove):
assert isinstance(x.type, TypedListType)
......@@ -292,13 +298,10 @@ class Remove(Op):
array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list
"""
if isinstance(toRemove, numpy.ndarray):
for y in range(out[0].__len__()):
if numpy.array_equal(out[0][y], toRemove):
for y in range(out[0].__len__()):
if node.inputs[0].ttype.values_eq(out[0][y], toRemove):
del out[0][y]
break
else:
out[0].remove(toRemove)
def __str__(self):
return self.__class__.__name__
......@@ -317,7 +320,7 @@ class Reverse(Op):
return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self)) ^ self.inplace
return hash(type(self)) ^ hash(self.inplace)
def make_node(self, x):
assert isinstance(x.type, TypedListType)
......@@ -347,6 +350,9 @@ class Reverse(Op):
%(output_name)s = %(x_name)s;
""" % locals()
return init + """
if(%(output_name)s==NULL){
%(fail)s
};
if(PyList_Reverse((PyObject*) %(output_name)s)==-1){
%(fail)s
};
......@@ -367,7 +373,6 @@ class Index(Op):
def make_node(self, x, elem):
assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type
self.values_eq = x.ttype.values_eq
return Apply(self, [x, elem], [T.scalar()])
def perform(self, node, (x, elem), (out, )):
......@@ -377,7 +382,7 @@ class Index(Op):
being thrown when trying to remove a matrix from a matrices list
"""
for y in range(len(x)):
if self.values_eq(x[y], elem):
if node.inputs[0].ttype.values_eq(x[y], elem):
out[0] = numpy.asarray(y, dtype=theano.config.floatX)
break
......@@ -398,7 +403,6 @@ class Count(Op):
def make_node(self, x, elem):
assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type
self.values_eq = x.ttype.values_eq
return Apply(self, [x, elem], [T.scalar()])
def perform(self, node, (x, elem), (out, )):
......@@ -409,7 +413,7 @@ class Count(Op):
"""
out[0] = 0
for y in range(len(x)):
if self.values_eq(x[y], elem):
if node.inputs[0].ttype.values_eq(x[y], elem):
out[0] += 1
out[0] = numpy.asarray(out[0], dtype=theano.config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论