提交 c74b0283 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

add flatnonzero and tests

上级 8f8b6e1f
......@@ -3229,6 +3229,29 @@ class Nonzero(gof.Op):
nonzero = Nonzero()
def flatnonzero(a):
"""
Return indices that are non-zero in the flattened version of a.
This is equivalent to a.flatten().nonzero().
Parameters
----------
a : tensor
Input tensor
Returns
-------
res : vector
Output vector, containing the indices of the elements of `a.flatten()`
that are non-zero.
See Also
--------
nonzero : Return the indices of the non-zero elements of the input array.
"""
return nonzero(a.flatten())
class Tri(gof.Op):
def __init__(self, dtype=None):
......
......@@ -41,7 +41,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
itensor3, Tile, AdvancedIncSubtensor, switch, Diagonal, Diag,
nonzero)
nonzero, flatnonzero)
from theano.tests import unittest_tools as utt
from theano.printing import debugprint
......@@ -1926,6 +1926,30 @@ def test_nonzero():
rand4d[rand4d > rand4d.mean()] = 0
check(rand4d)
def test_flatnonzero():
def check(m):
m_symb = theano.tensor.tensor(dtype=m.dtype,
broadcastable = (False,) * m.ndim)
f = function([m_symb], flatnonzero(m_symb))
result = f(m)
assert numpy.allclose(result, numpy.flatnonzero(m))
rand1d = rand(8)
rand1d[rand1d > rand1d.mean()] = 0
check(rand1d)
rand2d = rand(8, 9)
rand2d[rand2d > rand2d.mean()] = 0
check(rand2d)
rand3d = rand(8, 9, 10)
rand3d[rand3d > rand3d.mean()] = 0
check(rand3d)
rand4d = rand(8, 9, 10, 11)
rand4d[rand4d > rand4d.mean()] = 0
check(rand4d)
def test_identity():
def check(dtype):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论