提交 aa14269f authored 作者: Adam Becker's avatar Adam Becker

add back return_[values/indices]

because __props__ should be read only
上级 73197e70
...@@ -30,13 +30,17 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -30,13 +30,17 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
__props__ = TopKOp.__props__ __props__ = TopKOp.__props__
_f16_ok = True _f16_ok = True
def __init__(self, axis=-1, idx_dtype='int64'): def __init__(
self, axis=-1,
idx_dtype='int64',
return_values=True,
return_indices=True):
GpuKernelBase.__init__(self) GpuKernelBase.__init__(self)
TopKOp.__init__( TopKOp.__init__(
self, axis=axis, self, axis=axis,
idx_dtype=idx_dtype) idx_dtype=idx_dtype,
self.return_values = True return_values=return_values,
self.return_indices = True return_indices=return_indices)
def c_headers(self): def c_headers(self):
return ['gpuarray_api.h', 'gpuarray_helper.h', 'numpy_compat.h'] return ['gpuarray_api.h', 'gpuarray_helper.h', 'numpy_compat.h']
...@@ -291,8 +295,10 @@ def local_gpua_topkop(op, ctx_name, inputs, outputs): ...@@ -291,8 +295,10 @@ def local_gpua_topkop(op, ctx_name, inputs, outputs):
x, k = inputs x, k = inputs
x = as_gpuarray_variable(x, ctx_name) x = as_gpuarray_variable(x, ctx_name)
op = GpuTopKOp(axis=axis, idx_dtype=op.idx_dtype) gpu_op = GpuTopKOp(
op.return_values = rv axis=axis,
op.return_indices = ri idx_dtype=op.idx_dtype,
rets = op(x, k) return_values=rv,
return_indices=ri)
rets = gpu_op(x, k)
return rets return rets
...@@ -342,16 +342,23 @@ class TopKOp(theano.Op): ...@@ -342,16 +342,23 @@ class TopKOp(theano.Op):
def __init__( def __init__(
self, self,
axis=-1, axis=-1,
idx_dtype='int64'): idx_dtype='int64',
return_values=True,
return_indices=True
):
if not isinstance(axis, int): if not isinstance(axis, int):
raise TypeError( raise TypeError(
'"axis" parameter must be integer, got "%s"' % type(axis)) '"axis" parameter must be integer, got "%s"' % type(axis))
if idx_dtype not in theano.tensor.integer_dtypes: if idx_dtype not in theano.tensor.integer_dtypes:
raise TypeError( raise TypeError(
'"idx_dtype" parameter must be an integer dtype, got "%s"' % idx_dtype) '"idx_dtype" parameter must be an integer dtype, got "%s"' % idx_dtype)
if not (return_indices or return_values):
raise ValueError("Neither return_values nor return_indices is True, this isn't allowd")
self.axis = axis self.axis = axis
self.return_indices = True self.return_values = return_values
self.return_values = True self.return_indices = return_indices
self.idx_dtype = idx_dtype self.idx_dtype = idx_dtype
def __str__(self): def __str__(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论