Commit 901a7a0b authored by Adam Becker

add useless opt for topk

Parent: d95bf06c
...@@ -35,6 +35,7 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, ...@@ -35,6 +35,7 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
advanced_subtensor, advanced_subtensor,
advanced_subtensor1, advanced_subtensor1,
advanced_inc_subtensor1) advanced_inc_subtensor1)
from theano.tensor.sort import TopKOp
from theano import scalar from theano import scalar
from theano.scalar import basic from theano.scalar import basic
from theano.tensor import basic as T from theano.tensor import basic as T
...@@ -7548,3 +7549,40 @@ def local_merge_alloc(node): ...@@ -7548,3 +7549,40 @@ def local_merge_alloc(node):
dim_outer, T.eq(dim_outer, dim_inner)) dim_outer, T.eq(dim_outer, dim_inner))
i += 1 i += 1
return [T.alloc(inputs_inner[0], *dims_outer)] return [T.alloc(inputs_inner[0], *dims_outer)]
@register_useless('fast_compile')
@gof.local_optimizer([TopKOp])
def local_useless_topk(node):
    """Prune the unused output of a two-output TopKOp.

    A ``TopKOp`` built with both ``return_values`` and ``return_indices``
    produces two outputs.  When exactly one of them is consumed, replace
    the node with a single-output ``TopKOp`` that computes only the
    needed result.

    Returns
    -------
    dict or False or None
        ``{old_output: new_output}`` when the rewrite applies; ``False``
        (or ``None`` for a non-TopKOp node) otherwise.
    """
    op = node.op
    if not isinstance(op, TopKOp):
        return
    # Only applies when the op currently computes both outputs.
    if not (op.return_values and op.return_indices):
        return False
    x, k = node.inputs
    # Both flags are set here, so outputs[0] is values and outputs[-1]
    # is indices; a non-empty clients list means the output is used.
    ret_val = bool(node.outputs[0].clients)
    ret_idx = bool(node.outputs[-1].clients)
    if not (ret_val ^ ret_idx):
        # both used -> nothing to remove
        # both unused -> let the graph pruner handle it
        return False
    # Keep the consumed output: outputs[0] (values) when ret_val,
    # outputs[1] (indices) when ret_idx (bool indexes as 0/1).
    old_output = node.outputs[ret_idx]
    # A single-output op's __call__ returns the Variable directly, so no
    # trailing ``[0]`` indexing here -- that would build a spurious
    # Subtensor selecting element 0 of the tensor result.
    new_output = TopKOp(
        axis=op.axis,
        idx_dtype=op.idx_dtype,
        return_values=ret_val,
        return_indices=ret_idx)(x, k)
    return {old_output: new_output}
...@@ -413,7 +413,7 @@ class TopKOp(theano.Op): ...@@ -413,7 +413,7 @@ class TopKOp(theano.Op):
x, k = inputs x, k = inputs
k_grad = grad_undefined(self, 1, k, 'topk: k is not differentiable') k_grad = grad_undefined(self, 1, k, 'topk: k is not differentiable')
if not (self.return_indices and self.return_values): if not (self.return_indices or self.return_values):
x_grad = grad_undefined( x_grad = grad_undefined(
self, 0, x, 'topk: cannot get gradient' self, 0, x, 'topk: cannot get gradient'
' without both indices and values') ' without both indices and values')
...@@ -469,7 +469,7 @@ def topk(x, kth, axis=-1, sorted=True, idx_dtype='int64'): ...@@ -469,7 +469,7 @@ def topk(x, kth, axis=-1, sorted=True, idx_dtype='int64'):
raise NotImplementedError("sorted=True is not supported yet.") raise NotImplementedError("sorted=True is not supported yet.")
if axis is None: if axis is None:
x = theano.tensor.flatten(x) x = theano.tensor.flatten(x)
axis = -1 axis = 0
return TopKOp(axis=axis, idx_dtype=idx_dtype)(x, kth)[0] return TopKOp(axis=axis, idx_dtype=idx_dtype)(x, kth)[0]
......
...@@ -283,6 +283,9 @@ class Test_TopK(unittest.TestCase): ...@@ -283,6 +283,9 @@ class Test_TopK(unittest.TestCase):
x = theano.tensor.vector(name='x', dtype=dtype) x = theano.tensor.vector(name='x', dtype=dtype)
y = topk(x, k, sorted=sorted) y = topk(x, k, sorted=sorted)
fn = theano.function([x], y) fn = theano.function([x], y)
# assert local_useless_topk opt is done properly
assert 1 == len(fn.maker.fgraph.outputs[0].owner.outputs)
# generate a all-unique array # generate a all-unique array
xval = gen_unique_vector(size, dtype) xval = gen_unique_vector(size, dtype)
yval = fn(xval) yval = fn(xval)
...@@ -307,6 +310,10 @@ class Test_TopK(unittest.TestCase): ...@@ -307,6 +310,10 @@ class Test_TopK(unittest.TestCase):
x = theano.tensor.vector(name='x', dtype=dtype) x = theano.tensor.vector(name='x', dtype=dtype)
y = argtopk(x, k, sorted=sorted, idx_dtype=idx_dtype) y = argtopk(x, k, sorted=sorted, idx_dtype=idx_dtype)
fn = theano.function([x], y) fn = theano.function([x], y)
# assert local_useless_topk opt is done properly
assert 1 == len(fn.maker.fgraph.outputs[0].owner.outputs)
# generate a all-unique array # generate a all-unique array
xval = gen_unique_vector(size, dtype) xval = gen_unique_vector(size, dtype)
yval = fn(xval) yval = fn(xval)
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment