提交 41509b28 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make sure we do not use the sorted GpuTopK option.

上级 625cbf7a
...@@ -21,8 +21,10 @@ except ImportError as e: ...@@ -21,8 +21,10 @@ except ImportError as e:
# TODO GPU sort / argsort # TODO GPU sort / argsort
class GpuTopKOp(GpuKernelBase, TopKOp): class GpuTopKOp(GpuKernelBase, TopKOp):
''' '''Implements TopKOp on gpu
Implements TopKOp on gpu
Currently the output seem sorted, but we do not test it. So as on
the CPU, we only support sorted=False for now.
''' '''
__props__ = TopKOp.__props__ __props__ = TopKOp.__props__
...@@ -35,6 +37,9 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -35,6 +37,9 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
return_values=True, return_values=True,
return_indices=True return_indices=True
): ):
if sorted:
raise NotImplementedError(
"GpuTopK currently is not sure to give sorted output even if they look sorted..")
GpuKernelBase.__init__(self) GpuKernelBase.__init__(self)
TopKOp.__init__( TopKOp.__init__(
self, axis=axis, self, axis=axis,
...@@ -334,7 +339,8 @@ def local_gpua_topkop(op, ctx_name, inputs, outputs): ...@@ -334,7 +339,8 @@ def local_gpua_topkop(op, ctx_name, inputs, outputs):
ri = op.return_indices ri = op.return_indices
x, k = inputs x, k = inputs
x = as_gpuarray_variable(x, ctx_name) x = as_gpuarray_variable(x, ctx_name)
if op.sorted:
return
gpu_op = GpuTopKOp( gpu_op = GpuTopKOp(
axis=axis, axis=axis,
sorted=op.sorted, sorted=op.sorted,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论