提交 da44df32 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

raise error for scalar values

上级 4ac4e204
......@@ -3222,17 +3222,21 @@ class Nonzero(gof.Op):
"""
def make_node(self, a):
a = as_tensor_variable(a)
if a.ndim == 0:
raise ValueError('Nonzero only supports non-scalar arrays.')
output = [TensorType(dtype='int64', broadcastable=(False, False))()]
return gof.Apply(self, [a], output)
def perform(self, node, inp, out_):
a = inp[0]
out, = out_
result_tuple = numpy.nonzero(a)
if len(result_tuple) > 0 and len(result_tuple[0]) > 0:
if len(result_tuple[0]) > 0:
result = numpy.vstack(result_tuple)
else:
result = numpy.zeros((len(result_tuple), 0))
out[0] = result.astype('int64')
def grad(self, inp, grads):
......@@ -3306,6 +3310,8 @@ def flatnonzero(a):
nonzero : Return the indices of the non-zero elements of the input array.
nonzero_values : Return the non-zero elements of the input array
"""
if a.ndim == 0:
raise ValueError('Nonzero only supports non-scalar arrays.')
return nonzero(a.flatten(), return_matrix=True)[0]
def nonzero_values(a):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论