提交 3ca4c12f authored 作者: Adam Becker's avatar Adam Becker

fix crash

上级 a4533f5b
......@@ -30,13 +30,13 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
__props__ = TopKOp.__props__
_f16_ok = True
def __init__(self, axis=-1, return_values=True, return_indices=False, idx_dtype='int64'):
def __init__(self, axis=-1, idx_dtype='int64'):
GpuKernelBase.__init__(self)
TopKOp.__init__(
self, axis=axis,
return_values=return_values,
return_indices=return_indices,
idx_dtype=idx_dtype)
self.return_values = True
self.return_indices = True
def c_headers(self):
return ['gpuarray_api.h', 'gpuarray_helper.h', 'numpy_compat.h']
......@@ -291,9 +291,8 @@ def local_gpua_topkop(op, ctx_name, inputs, outputs):
x, k = inputs
x = as_gpuarray_variable(x, ctx_name)
rets = GpuTopKOp(
axis=axis,
return_values=rv,
return_indices=ri,
idx_dtype=op.idx_dtype)(x, k)
op = GpuTopKOp(axis=axis, idx_dtype=op.idx_dtype)
op.return_values = rv
op.return_indices = ri
rets = op(x, k)
return rets
from __future__ import absolute_import, print_function, division
import numpy as np
import theano
from theano.tensor.basic import mul, arange
......@@ -223,6 +222,8 @@ def argsort(a, axis=-1, kind='quicksort', order=None):
def _topk_py_impl(op, x, k, axis, idx_dtype):
ndim = x.ndim
assert -ndim <= axis < ndim
axis %= ndim
if k == 0:
raise ValueError('topk: k cannot be zero')
if abs(k) == 1:
......@@ -245,8 +246,7 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
fn_argmax(x, axis=axis), axis)
return zi.astype(idx_dtype)
asize = x.shape[axis]
if asize == abs(k):
if x.shape[axis] == abs(k):
if not op.return_indices:
return x.copy()
else:
......@@ -263,7 +263,7 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
return zi
idx = [slice(None)] * ndim
idx[axis] = slice(-k, None) if k > 0 else idx[axis] = slice(-k)
idx[axis] = (slice(-k, None) if k > 0 else slice(-k))
if not op.return_indices:
zv = np.partition(x, -k, axis=axis)[idx]
......@@ -336,7 +336,6 @@ class TopKOp(theano.Op):
# one result is needed
# TODO R_op
__props__ = ('axis', 'return_values', 'return_indices', 'idx_dtype')
def __init__(
......@@ -345,7 +344,7 @@ class TopKOp(theano.Op):
idx_dtype='int64'):
if not isinstance(axis, int):
raise TypeError(
'"axis" parameter must be integer, got "%s"' % type(self.axis))
'"axis" parameter must be integer, got "%s"' % type(axis))
if idx_dtype not in theano.tensor.integer_dtypes:
raise TypeError(
'"idx_dtype" parameter must be an integer dtype, got "%s"' % idx_dtype)
......@@ -382,10 +381,7 @@ class TopKOp(theano.Op):
def perform(self, node, inputs, output_storage):
x, k = inputs
ndim = x.ndim
axis = self.axis
assert -ndim <= axis < ndim
axis %= ndim
if not self.return_indices:
pzv = output_storage[0]
pzv[0] = _topk_py_impl(self, x, k, axis, None)
......@@ -401,7 +397,6 @@ class TopKOp(theano.Op):
def infer_shape(self, node, inp_shapes):
_check_tensor_is_scalar(node.inputs[1])
shp = list(inp_shapes[0])
ndim = node.inputs[0].ndim
shp[self.axis] = np.abs(node.inputs[1])
shp = tuple(shp)
return [shp for i in [self.return_values, self.return_indices] if i]
......
......@@ -22,6 +22,7 @@ def gen_unique_vector(size, dtype):
return (retval[np.random.permutation(size)] - size * 1.5).astype(dtype)
'''
class Test_sort(unittest.TestCase):
def setUp(self):
......@@ -33,7 +34,7 @@ class Test_sort(unittest.TestCase):
a = tensor.dmatrix()
w = sort(a)
f = theano.function([a], w)
assert utt.assert_allclose(f(self.m_val), np.sort(self.m_val))
utt.assert_allclose(f(self.m_val), np.sort(self.m_val))
def test2(self):
a = tensor.dmatrix()
......@@ -43,7 +44,7 @@ class Test_sort(unittest.TestCase):
for axis_val in 0, 1:
gv = f(self.m_val, axis_val)
gt = np.sort(self.m_val, axis_val)
assert utt.assert_allclose(gv, gt)
utt.assert_allclose(gv, gt)
def test3(self):
a = tensor.dvector()
......@@ -51,7 +52,7 @@ class Test_sort(unittest.TestCase):
f = theano.function([a], w2)
gv = f(self.v_val)
gt = np.sort(self.v_val)
assert utt.assert_allclose(gv, gt)
utt.assert_allclose(gv, gt)
def test4(self):
a = tensor.dmatrix()
......@@ -61,7 +62,7 @@ class Test_sort(unittest.TestCase):
for axis_val in 0, 1:
gv = f(self.m_val, axis_val)
gt = np.sort(self.m_val, axis_val)
assert utt.assert_allclose(gv, gt)
utt.assert_allclose(gv, gt)
def test5(self):
a1 = SortOp("mergesort", [])
......@@ -78,7 +79,7 @@ class Test_sort(unittest.TestCase):
f = theano.function([a], l)
gv = f(self.m_val)
gt = np.sort(self.m_val, None)
assert utt.assert_allclose(gv, gt)
utt.assert_allclose(gv, gt)
def test_grad_vector(self):
data = np.random.rand(10).astype(theano.config.floatX)
......@@ -170,7 +171,7 @@ def test_argsort():
f = theano.function([a], w)
gv = f(m_val)
gt = np.argsort(m_val)
assert utt.assert_allclose(gv, gt)
utt.assert_allclose(gv, gt)
# Example 2
a = tensor.dmatrix()
......@@ -180,7 +181,7 @@ def test_argsort():
for axis_val in 0, 1:
gv = f(m_val, axis_val)
gt = np.argsort(m_val, axis_val)
assert utt.assert_allclose(gv, gt)
utt.assert_allclose(gv, gt)
# Example 3
a = tensor.dvector()
......@@ -188,7 +189,7 @@ def test_argsort():
f = theano.function([a], w2)
gv = f(v_val)
gt = np.argsort(v_val)
assert utt.assert_allclose(gv, gt)
utt.assert_allclose(gv, gt)
# Example 4
a = tensor.dmatrix()
......@@ -198,7 +199,7 @@ def test_argsort():
for axis_val in 0, 1:
gv = f(m_val, axis_val)
gt = np.argsort(m_val, axis_val)
assert utt.assert_allclose(gv, gt)
utt.assert_allclose(gv, gt)
# Example 5
a = tensor.dmatrix()
......@@ -216,7 +217,7 @@ def test_argsort():
f = theano.function([a], w2)
gv = f(m_val)
gt = np.argsort(m_val, None)
assert utt.assert_allclose(gv, gt)
utt.assert_allclose(gv, gt)
def test_argsort_grad():
......@@ -229,6 +230,7 @@ def test_argsort_grad():
data = np.random.rand(2, 3, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=2), [data])
'''
class Test_TopK(unittest.TestCase):
......@@ -265,7 +267,7 @@ class Test_TopK(unittest.TestCase):
xval = np.asarray([1]).astype(dtype)
yvval, yival = fn(xval)
assert yival == np.asarray([0], dtype=idx_dtype)
assert utt.assert_allclose(xval, yvval)
utt.assert_allclose(xval, yvval)
assert yvval.dtype == xval.dtype
assert yival.dtype == np.dtype(idx_dtype)
......@@ -285,11 +287,11 @@ class Test_TopK(unittest.TestCase):
# generate a all-unique array
xval = gen_unique_vector(size, dtype)
yval = fn(xval)
idx = slice(-k, None) if k > 0 else slice(-k)
idx = (slice(-k, None) if k > 0 else slice(-k))
goal = np.sort(xval)[idx]
assert yval.dtype == goal.dtype
assert utt.assert_allclose(np.sort(yval), goal)
utt.assert_allclose(np.sort(yval), goal)
@utt.parameterized.expand(chain(
product(
......@@ -308,7 +310,7 @@ class Test_TopK(unittest.TestCase):
# generate a all-unique array
xval = gen_unique_vector(size, dtype)
yval = fn(xval)
idx = slice(-k, None) if k > 0 else slice(-k)
idx = (slice(-k, None) if k > 0 else slice(-k))
goal = np.argsort(xval)[idx].astype(idx_dtype)
# due to uniqueness, we expect indices same
......@@ -331,13 +333,13 @@ class Test_TopK(unittest.TestCase):
# generate a all-unique array
xval = gen_unique_vector(size, dtype)
yvval, yival = fn(xval)
idx = slice(-k, None) if k > 0 else slice(-k)
idx = (slice(-k, None) if k > 0 else slice(-k))
goali = np.argsort(xval)[idx].astype(idx_dtype)
goalv = xval[goali]
# due to uniqueness, we expect indices same
assert np.all(xval[np.sort(yival)] == xval[np.sort(goali)])
assert utt.assert_allclose(np.sort(yvval), goalv)
utt.assert_allclose(np.sort(yvval), goalv)
@utt.parameterized.expand(chain(
product(
......@@ -356,11 +358,9 @@ class Test_TopK(unittest.TestCase):
xval = np.repeat(np.random.uniform(-100., 100., size=size // 2).astype(dtype), 2)
xval = xval[np.random.permutation(size)]
yval = fn(xval)
idx = slice(-k, None) if k > 0 else slice(-k)
idx = (slice(-k, None) if k > 0 else slice(-k))
goal = np.argsort(xval)[idx].astype('int32')
print(goal)
print(np.argsort(xval))
assert utt.assert_allclose(np.sort(xval[yval]), np.sort(xval[goal]))
utt.assert_allclose(np.sort(xval[yval]), np.sort(xval[goal]))
@utt.parameterized.expand(product(
((17, 15), (2, 3, 5, 7, 11), (2017, 5, 3)),
......@@ -391,14 +391,6 @@ class Test_TopK(unittest.TestCase):
idx = (slice(None),) * l + (idx,) + (slice(None),) * (r - 1)
goal = np.argsort(xval, axis=axis)[idx].astype(idx_dtype)
print(dict(k=k, axis=axis, shp=shp))
print('x:')
print(xval)
print('y:')
print(np.sort(yval, axis=axis))
print('goal:')
print(np.sort(goal, axis=axis))
# print(np.argsort(xval))
assert np.all(np.sort(yval, axis=axis) == np.sort(goal, axis=axis))
@utt.parameterized.expand(product(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论