提交 d42672be authored 作者: Iban Harlouchet's avatar Iban Harlouchet

Commit 2

上级 d50fa534
......@@ -2,7 +2,6 @@ import numpy as np
import numpy
import warnings
import theano
import theano.tensor as T
from theano.tensor import basic
from theano.tensor import nlinalg
......@@ -1012,33 +1011,33 @@ class Unique(theano.Op):
__props__ = ("return_index", "return_inverse", "return_counts")
def __init__(self, return_index=False, return_inverse=False,
return_counts=False):
return_counts=False):
self.return_index = return_index
self.return_inverse = return_inverse
self.return_counts = return_counts
if self.return_counts == True and np.__version__ < "1.9.0" :
raise RuntimeError(
"Numpy version = " + np.__version__ +
". Option 'return_counts=True' works starting"
"Numpy version = " + np.__version__ +
". Option 'return_counts=True' works starting"
" from version 1.9.0.")
def make_node(self, x):
x = T.as_tensor_variable(x)
x = basic.as_tensor_variable(x)
#x = x.flatten()
outputs = []
# output0 = T.TensorType(broadcastable=[False], dtype=x.dtype)()
# output0 = basic.TensorType(broadcastable=[False], dtype=x.dtype)()
output0 = x.flatten().type()
outputs.append(output0)
typ = T.TensorType(broadcastable=[False], dtype='int64')
typ = basic.TensorType(broadcastable=[False], dtype='int64')
if self.return_index :
output1 = typ()
outputs.append(output1)
if self.return_inverse :
output2 = typ()#T.TensorType(broadcastable=[False], dtype=x.dtype)
output2 = typ()
outputs.append(output2)
if self.return_counts :
output3 = typ()#T.TensorType(broadcastable=[False], dtype=x.dtype)
output3 = typ()
outputs.append(output3)
return theano.Apply(self, [x], outputs)
......@@ -1056,7 +1055,7 @@ class Unique(theano.Op):
outs = np.unique(x,**param)
if ((not self.return_inverse) and
(not self.return_index) and
(not self.return_counts)):
(not self.return_counts)):
z[0][0]=outs
else :
for i in range(len(outs)):
......@@ -1072,4 +1071,3 @@ class Unique(theano.Op):
ret[1] = shape
return ret
return ret
......@@ -668,7 +668,7 @@ class test_Unique(utt.InferShapeTester):
super(test_Unique, self).setUp()
self.op_class = Unique
self.ops = [Unique(), Unique(True), Unique(True, True),
Unique(False, True)]#, Unique(True, True, True)]
Unique(False, True)]#, Unique(True, True, True)]
def test_basic_vector(self):
"""
......@@ -716,7 +716,7 @@ class test_Unique(utt.InferShapeTester):
Testing the infer_shape with a vector.
"""
# TODO
pass
pass
def test_infer_shape_matrix(self):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论