提交 4765f30e authored 作者: Iban Harlouchet's avatar Iban Harlouchet

Additional correctins on class Unique

上级 83570378
import numpy as np import numpy as np
import numpy import numpy
import warnings import warnings
import theano
import theano
from theano.tensor import basic from theano.tensor import basic
from theano.tensor import nlinalg from theano.tensor import nlinalg
from theano import gof, scalar from theano import gof, scalar
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
tensor = basic tensor = basic
...@@ -1014,29 +1013,24 @@ class Unique(theano.Op): ...@@ -1014,29 +1013,24 @@ class Unique(theano.Op):
return_counts=False): return_counts=False):
self.return_index = return_index self.return_index = return_index
self.return_inverse = return_inverse self.return_inverse = return_inverse
self.return_counts = return_counts self.return_counts = return_counts
if self.return_counts == True and np.__version__ < "1.9.0" : numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
if self.return_counts == True and bool(numpy_ver < [1, 9]) :
raise RuntimeError( raise RuntimeError(
"Numpy version = " + np.__version__ + "Numpy version = " + np.__version__ +
". Option 'return_counts=True' works starting" ". Option 'return_counts=True' works starting"
" from version 1.9.0.") " from version 1.9.0.")
def make_node(self, x): def make_node(self, x):
x = basic.as_tensor_variable(x) x = basic.as_tensor_variable(x)
#x = x.flatten() outputs = [basic.TensorType(broadcastable=[False], dtype=x.dtype)()]
outputs = []
# output0 = basic.TensorType(broadcastable=[False], dtype=x.dtype)()
output0 = x.flatten().type()
outputs.append(output0)
typ = basic.TensorType(broadcastable=[False], dtype='int64') typ = basic.TensorType(broadcastable=[False], dtype='int64')
if self.return_index : if self.return_index :
outputs.append(typ()) outputs.append(typ())
if self.return_inverse : if self.return_inverse :
outputs.append(typ()) outputs.append(typ())
if self.return_counts : if self.return_counts :
outputs.append(typ()) outputs.append(typ())
return theano.Apply(self, [x], outputs) return theano.Apply(self, [x], outputs)
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
...@@ -1061,7 +1055,7 @@ class Unique(theano.Op): ...@@ -1061,7 +1055,7 @@ class Unique(theano.Op):
def infer_shape(self, node, i0_shapes): def infer_shape(self, node, i0_shapes):
ret = node.fgraph.shape_feature.default_infer_shape(node, i0_shapes) ret = node.fgraph.shape_feature.default_infer_shape(node, i0_shapes)
if self.return_inverse : if self.return_inverse :
shape = (np.prod(i0_shapes[0]), ) shape = (basic.prod(i0_shapes[0]), )
if self.return_index : if self.return_index :
ret[2] = shape ret[2] = shape
return ret return ret
......
...@@ -2,10 +2,9 @@ from nose.plugins.attrib import attr ...@@ -2,10 +2,9 @@ from nose.plugins.attrib import attr
import numpy as np import numpy as np
import numpy import numpy
import unittest import unittest
import theano
import theano
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.extra_ops import (CumsumOp, cumsum, CumprodOp, cumprod, from theano.tensor.extra_ops import (CumsumOp, cumsum, CumprodOp, cumprod,
BinCountOp, bincount, DiffOp, diff, BinCountOp, bincount, DiffOp, diff,
squeeze, compress, RepeatOp, repeat, squeeze, compress, RepeatOp, repeat,
...@@ -671,7 +670,7 @@ class test_Unique(utt.InferShapeTester): ...@@ -671,7 +670,7 @@ class test_Unique(utt.InferShapeTester):
Unique(True), Unique(True),
Unique(False, True), Unique(False, True),
Unique(True, True)] Unique(True, True)]
if np.__version__ >= "1.9.0" : if bool(numpy_ver >= [1, 9]) :
self.ops.extend([ self.ops.extend([
Unique(False, False, True), Unique(False, False, True),
Unique(True, False, True), Unique(True, False, True),
...@@ -689,7 +688,7 @@ class test_Unique(utt.InferShapeTester): ...@@ -689,7 +688,7 @@ class test_Unique(utt.InferShapeTester):
np.unique(inp, True), np.unique(inp, True),
np.unique(inp, False, True), np.unique(inp, False, True),
np.unique(inp, True, True)] np.unique(inp, True, True)]
if np.__version__ >= "1.9.0" : if bool(numpy_ver >= [1, 9]) :
list_outs_expected.extend([ list_outs_expected.extend([
np.unique(inp, False, False, True), np.unique(inp, False, False, True),
np.unique(inp, True, False, True), np.unique(inp, True, False, True),
...@@ -698,11 +697,8 @@ class test_Unique(utt.InferShapeTester): ...@@ -698,11 +697,8 @@ class test_Unique(utt.InferShapeTester):
for op, outs_expected in zip(self.ops, list_outs_expected) : for op, outs_expected in zip(self.ops, list_outs_expected) :
f = theano.function(inputs=[x], outputs=op(x, return_list=True)) f = theano.function(inputs=[x], outputs=op(x, return_list=True))
outs = f(inp) outs = f(inp)
print outs
# Compare the result computed to the expected value. # Compare the result computed to the expected value.
for out, out_exp in zip(outs, outs_expected): for out, out_exp in zip(outs, outs_expected):
print out
print out_exp
utt.assert_allclose(out, out_exp) utt.assert_allclose(out, out_exp)
def test_basic_matrix(self): def test_basic_matrix(self):
...@@ -715,7 +711,7 @@ class test_Unique(utt.InferShapeTester): ...@@ -715,7 +711,7 @@ class test_Unique(utt.InferShapeTester):
np.unique(inp, True), np.unique(inp, True),
np.unique(inp, False, True), np.unique(inp, False, True),
np.unique(inp, True, True)] np.unique(inp, True, True)]
if np.__version__ >= "1.9.0" : if bool(numpy_ver >= [1, 9]) :
list_outs_expected.extend([ list_outs_expected.extend([
np.unique(inp, False, False, True), np.unique(inp, False, False, True),
np.unique(inp, True, False, True), np.unique(inp, True, False, True),
...@@ -724,11 +720,8 @@ class test_Unique(utt.InferShapeTester): ...@@ -724,11 +720,8 @@ class test_Unique(utt.InferShapeTester):
for op, outs_expected in zip(self.ops, list_outs_expected): for op, outs_expected in zip(self.ops, list_outs_expected):
f = theano.function(inputs=[x], outputs=op(x, return_list=True)) f = theano.function(inputs=[x], outputs=op(x, return_list=True))
outs = f(inp) outs = f(inp)
print outs
# Compare the result computed to the expected value. # Compare the result computed to the expected value.
for out, out_exp in zip(outs, outs_expected): for out, out_exp in zip(outs, outs_expected):
print out
print out_exp
utt.assert_allclose(out, out_exp) utt.assert_allclose(out, out_exp)
def test_infer_shape_vector(self): def test_infer_shape_vector(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论