提交 244e24ec authored 作者: Seon-Wook Park's avatar Seon-Wook Park

Replace CompressOp with .take(.flatnonzero)

上级 a0bf02ba
......@@ -511,53 +511,6 @@ def squeeze(x):
return view
class CompressOp(theano.Op):
# See the compress function for docstring
def __init__(self, axis=None):
self.axis = axis
def __eq__(self, other):
return (type(self) == type(other) and
self.axis == other.axis)
def __hash__(self):
return hash(type(self)) ^ hash(self.axis)
def make_node(self, condition, x):
x = basic.as_tensor_variable(x)
condition = basic.as_tensor_variable(condition)
if condition.ndim != 1:
raise TypeError("Conditions cannot have a number of "
"dimension different of 1.")
return theano.Apply(self, [condition, x], [x.type()])
def perform(self, node, inputs, output_storage):
condition = inputs[0]
x = inputs[1]
z = output_storage[0]
z[0] = np.compress(condition.astype(bool), x, axis=self.axis)
print z[0]
def infer_shape(self, node, ins_shapes):
condition = node.inputs[0]
n = condition.ndim # TODO: Find way to get condition vector shape
if self.axis is None:
out_shape = (n,)
else:
out_shape = list(ins_shapes[1])
out_shape[self.axis] -= n
out_shape = tuple(out_shape)
print out_shape
return [out_shape]
def __str__(self):
return self.__class__.__name__
def compress(condition, x, axis=None, out=None):
"""Return selected slices of an array along given axis.
......@@ -577,7 +530,8 @@ def compress(condition, x, axis=None, out=None):
"""
# This is done to keep the same function signature then NumPy.
assert out is None
return CompressOp(axis=axis)(condition, x)
indices = theano.tensor.basic.flatnonzero(condition)
return x.take(indices, axis=axis)
class RepeatOp(theano.Op):
......
......@@ -7,8 +7,8 @@ from theano.tests import unittest_tools as utt
from theano.tensor.extra_ops import (CumsumOp, cumsum, CumprodOp, cumprod,
BinCountOp, bincount, DiffOp, diff,
squeeze, CompressOp, compress,
RepeatOp, repeat, Bartlett, bartlett,
squeeze, compress, RepeatOp, repeat,
Bartlett, bartlett,
FillDiagonal, fill_diagonal,
FillDiagonalOffset, fill_diagonal_offset,
to_one_hot)
......@@ -344,43 +344,46 @@ class SqueezeTester(utt.InferShapeTester):
assert numpy.allclose(tested, expected)
class TestCompressOp(utt.InferShapeTester):
class CompressTester(utt.InferShapeTester):
axis_list = [None,
0,
1]
cond_list = [[1, 0, 1, 0, 0, 1],
[0, 1, 1, 0],
[1, 1, 0, 1, 0]]
shape_list = [(2, 3),
(4, 3),
(3, 5)]
def setUp(self):
super(TestCompressOp, self).setUp()
self.op_class = CompressOp
self.op = CompressOp()
super(CompressTester, self).setUp()
self.op = compress
def test_compressOp(self):
x = T.dmatrix()
cond = T.dvector()
def test_op(self):
for axis, cond, shape in zip(self.axis_list, self.cond_list, self.shape_list):
cond_var = theano.tensor.ivector()
data = numpy.random.random(size=shape).astype(theano.config.floatX)
data_var = tensor.TensorType(theano.config.floatX, [False]*2)()
cond_val = np.array([1, 0, 1, 0], dtype=bool)
a = np.random.random((3, 4)).astype(config.floatX)
f = theano.function([cond_var, data_var], self.op(cond_var, data_var, axis=axis))
f = theano.function([cond, x], compress(cond, x))
assert np.allclose(np.compress(cond_val, a), f(cond_val, a))
expected = numpy.compress(cond, data, axis=axis)
tested = f(cond, data)
for axis in range(len(a.shape)):
g = theano.function([cond, x], compress(cond, x, axis=axis))
assert np.allclose(np.compress(cond_val, a, axis=axis), g(cond_val, a))
assert tested.shape == expected.shape
assert numpy.allclose(tested, expected)
def test_infer_shape(self):
x = T.dmatrix()
cond = T.dvector()
cond_val = np.array([1, 0, 1, 0], dtype=bool)
a = np.random.random((3, 4)).astype(config.floatX)
self._compile_and_check([cond, x],
[compress(cond, x)],
[cond_val, a],
self.op_class)
for axis in range(len(a.shape)):
self._compile_and_check([cond, x],
[compress(cond, x, axis=axis)],
[cond_val, a],
self.op_class)
for axis, cond, shape in zip(self.axis_list, self.cond_list, self.shape_list):
cond_var = theano.tensor.ivector()
data = numpy.random.random(size=shape).astype(theano.config.floatX)
data_var = tensor.TensorType(theano.config.floatX, [False]*2)()
self._compile_and_check([cond_var, data_var],
[self.op(cond_var, data_var, axis=axis)],
[cond, data],
tensor.AdvancedSubtensor1,
warn=False)
class TestRepeatOp(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论